23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
38 if (
auto vecType = dyn_cast<VectorType>(type))
39 return vecType.getElementType().isSignedInteger();
47 if (
auto vecType = dyn_cast<VectorType>(type))
48 return vecType.getElementType().isUnsignedInteger();
55 if (
auto intType = dyn_cast<IntegerType>(type))
56 return intType.getWidth();
57 if (
auto vecType = dyn_cast<VectorType>(type))
58 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59 return intType.getWidth();
66 "bitwidth is not supported for this type");
69 auto vecType = dyn_cast<VectorType>(type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
78 if (
auto vecTy = dyn_cast<VectorType>(type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(type).getWidth();
85 if (
auto vecType = dyn_cast<VectorType>(type)) {
86 auto integerType = cast<IntegerType>(vecType.getElementType());
89 auto integerType = cast<IntegerType>(type);
96 if (isa<VectorType>(srcType)) {
97 return LLVM::ConstantOp::create(
98 rewriter, loc, dstType,
102 return LLVM::ConstantOp::create(rewriter, loc, dstType,
109 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
110 auto floatType = cast<FloatType>(vecType.getElementType());
111 return LLVM::ConstantOp::create(
112 rewriter, loc, dstType,
116 auto floatType = cast<FloatType>(srcType);
117 return LLVM::ConstantOp::create(rewriter, loc, dstType,
130 auto srcType = value.
getType();
136 if (valueBitWidth < targetBitWidth)
137 return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
142 if (valueBitWidth > targetBitWidth)
143 return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
152 auto llvmVectorType = typeConverter.
convertType(vectorType);
154 Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
155 for (
unsigned i = 0; i < numElements; ++i) {
156 auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
158 broadcasted = LLVM::InsertElementOp::create(
159 rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index);
168 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, value, numElements, typeConverter, rewriter);
203 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
213 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
220 return LLVM::ConstantOp::create(
229 unsigned alignment,
bool isVolatile,
230 bool isNonTemporal) {
231 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232 auto dstType = typeConverter.
convertType(loadOp.getType());
236 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237 isVolatile, isNonTemporal);
240 auto storeOp = cast<spirv::StoreOp>(op);
241 spirv::StoreOpAdaptor adaptor(operands);
243 adaptor.getPtr(), alignment,
244 isVolatile, isNonTemporal);
259 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
263 auto llvmElementType = converter.
convertType(elementType);
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
295 if (!memberDecorations.empty())
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
316 getTypeConverter()->convertType(op.getComponentPtr().getType());
320 auto indices = llvm::to_vector<4>(adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(indexType);
326 LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
328 indices.insert(indices.begin(), zero);
330 auto elementType = getTypeConverter()->convertType(
331 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
335 adaptor.getBasePtr(), indices);
345 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
347 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
356 class BitFieldInsertPattern
362 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
364 auto srcType = op.getType();
365 auto dstType = getTypeConverter()->convertType(srcType);
372 *getTypeConverter(), rewriter);
374 *getTypeConverter(), rewriter);
378 Value maskShiftedByCount =
379 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
380 Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
381 maskShiftedByCount, minusOne);
382 Value maskShiftedByCountAndOffset =
383 LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
384 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
385 maskShiftedByCountAndOffset, minusOne);
390 LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
391 Value insertShiftedByOffset =
392 LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
394 insertShiftedByOffset);
400 class ConstantScalarAndVectorPattern
406 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
408 auto srcType = constOp.getType();
409 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
412 auto dstType = getTypeConverter()->convertType(srcType);
425 if (isa<VectorType>(srcType)) {
426 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
429 dstElementsAttr.mapValues(
430 signlessType, [&](
const APInt &value) {
return value; }));
433 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
434 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
439 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
444 class BitFieldSExtractPattern
450 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
452 auto srcType = op.getType();
453 auto dstType = getTypeConverter()->convertType(srcType);
460 *getTypeConverter(), rewriter);
462 *getTypeConverter(), rewriter);
465 IntegerType integerType;
466 if (
auto vecType = dyn_cast<VectorType>(srcType))
467 integerType = cast<IntegerType>(vecType.getElementType());
469 integerType = cast<IntegerType>(srcType);
473 isa<VectorType>(srcType)
474 ? LLVM::ConstantOp::create(
475 rewriter, loc, dstType,
477 : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
481 Value countPlusOffset =
482 LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
483 Value amountToShiftLeft =
484 LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
485 Value baseShiftedLeft = LLVM::ShlOp::create(
486 rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
489 Value amountToShiftRight =
490 LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
497 class BitFieldUExtractPattern
503 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
505 auto srcType = op.getType();
506 auto dstType = getTypeConverter()->convertType(srcType);
513 *getTypeConverter(), rewriter);
515 *getTypeConverter(), rewriter);
519 Value maskShiftedByCount =
520 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
521 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
526 LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
537 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
540 branchOp.getTarget());
545 class BranchConditionalConversionPattern
552 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
556 if (
auto weights = op.getBranchWeights()) {
558 for (
auto weight : weights->getAsRange<IntegerAttr>())
559 weightValues.push_back(weight.getInt());
564 op, op.getCondition(), op.getTrueBlockArguments(),
565 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
574 class CompositeExtractPattern
580 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
582 auto dstType = this->getTypeConverter()->convertType(op.getType());
586 Type containerType = op.getComposite().getType();
587 if (isa<VectorType>(containerType)) {
589 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
592 op, dstType, adaptor.getComposite(), index);
597 op, adaptor.getComposite(),
606 class CompositeInsertPattern
612 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
614 auto dstType = this->getTypeConverter()->convertType(op.getType());
618 Type containerType = op.getComposite().getType();
619 if (isa<VectorType>(containerType)) {
621 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
624 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
629 op, adaptor.getComposite(), adaptor.getObject(),
637 template <
typename SPIRVOp,
typename LLVMOp>
643 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
645 auto dstType = this->getTypeConverter()->convertType(op.getType());
648 rewriter.template replaceOpWithNewOp<LLVMOp>(
649 op, dstType, adaptor.getOperands(), op->getAttrs());
656 class ExecutionModePattern
662 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
667 ModuleOp module = op->getParentOfType<ModuleOp>();
668 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
669 std::string moduleName;
670 if (module.getName().has_value())
671 moduleName =
"_" + module.getName()->str();
674 std::string executionModeInfoName = llvm::formatv(
675 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
676 static_cast<uint32_t
>(executionModeAttr.getValue()));
689 fields.push_back(llvmI32Type);
690 ArrayAttr values = op.getValues();
691 if (!values.empty()) {
693 fields.push_back(arrayType);
695 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698 auto global = LLVM::GlobalOp::create(
700 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
703 Region ®ion = global.getInitializerRegion();
708 Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
709 Value executionMode = LLVM::ConstantOp::create(
710 rewriter, loc, llvmI32Type,
712 static_cast<uint32_t
>(executionModeAttr.getValue())));
714 structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
715 executionMode, position);
718 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
719 auto attr = values.getValue()[i];
720 Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
721 structValue = LLVM::InsertValueOp::create(
734 class GlobalVariablePattern
737 template <
typename... Args>
738 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
740 std::forward<Args>(args)...),
741 clientAPI(clientAPI) {}
744 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
748 if (op.getInitializer())
751 auto srcType = cast<spirv::PointerType>(op.getType());
752 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
759 auto storageClass = srcType.getStorageClass();
760 switch (storageClass) {
761 case spirv::StorageClass::Input:
762 case spirv::StorageClass::Private:
763 case spirv::StorageClass::Output:
764 case spirv::StorageClass::StorageBuffer:
765 case spirv::StorageClass::UniformConstant:
774 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
775 (storageClass == spirv::StorageClass::UniformConstant);
781 auto linkage = storageClass == spirv::StorageClass::Private
782 ? LLVM::Linkage::Private
783 : LLVM::Linkage::External;
784 StringAttr locationAttrName = op.getLocationAttrName();
785 IntegerAttr locationAttr = op.getLocationAttr();
787 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
792 newGlobalOp->setAttr(locationAttrName, locationAttr);
798 spirv::ClientAPI clientAPI;
803 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
809 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
812 Type fromType = op.getOperand().getType();
813 Type toType = op.getType();
815 auto dstType = this->getTypeConverter()->convertType(toType);
820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821 adaptor.getOperands());
825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826 adaptor.getOperands());
833 class FunctionCallPattern
839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
841 if (callOp.getNumResults() == 0) {
843 callOp,
TypeRange(), adaptor.getOperands(), callOp->getAttrs());
844 newOp.getProperties().operandSegmentSizes = {
845 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
851 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856 newOp.getProperties().operandSegmentSizes = {
857 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
864 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
870 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
873 auto dstType = this->getTypeConverter()->convertType(op.getType());
877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878 op, dstType, predicate, op.getOperand1(), op.getOperand2());
884 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
890 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
893 auto dstType = this->getTypeConverter()->convertType(op.getType());
897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898 op, dstType, predicate, op.getOperand1(), op.getOperand2());
903 class InverseSqrtPattern
909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
911 auto srcType = op.getType();
912 auto dstType = getTypeConverter()->convertType(srcType);
918 Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
925 template <
typename SPIRVOp>
931 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
933 if (!op.getMemoryAccess()) {
935 *this->getTypeConverter(), 0,
939 auto memoryAccess = *op.getMemoryAccess();
940 switch (memoryAccess) {
941 case spirv::MemoryAccess::Aligned:
943 case spirv::MemoryAccess::Nontemporal:
944 case spirv::MemoryAccess::Volatile: {
946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
950 *this->getTypeConverter(), alignment,
951 isVolatile, isNonTemporal);
961 template <
typename SPIRVOp>
967 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
969 auto srcType = notOp.getType();
970 auto dstType = this->getTypeConverter()->convertType(srcType);
977 isa<VectorType>(srcType)
978 ? LLVM::ConstantOp::create(
979 rewriter, loc, dstType,
981 : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983 notOp.getOperand(), mask);
989 template <
typename SPIRVOp>
995 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1023 adaptor.getOperands());
1032 bool convergent =
true) {
1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1039 func = LLVM::LLVMFuncOp::create(
1040 b, symbolTable->
getLoc(), name,
1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043 func.setConvergent(convergent);
1044 func.setNoUnwind(
true);
1045 func.setWillReturn(
true);
1050 LLVM::LLVMFuncOp func,
1052 auto call = LLVM::CallOp::create(builder, loc, func, args);
1053 call.setCConv(func.getCConv());
1054 call.setConvergentAttr(func.getConvergentAttr());
1055 call.setNoUnwindAttr(func.getNoUnwindAttr());
1056 call.setWillReturnAttr(func.getWillReturnAttr());
1060 template <
typename BarrierOpTy>
1067 static constexpr StringRef getFuncName();
1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1072 constexpr StringRef funcName = getFuncName();
1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1078 Type voidTy = rewriter.
getType<LLVM::LLVMVoidType>();
1079 LLVM::LLVMFuncOp func =
1082 Location loc = controlBarrierOp->getLoc();
1083 Value execution = LLVM::ConstantOp::create(
1084 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1085 Value memory = LLVM::ConstantOp::create(
1086 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1087 Value semantics = LLVM::ConstantOp::create(
1088 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1091 {execution, memory, semantics});
1093 rewriter.
replaceOp(controlBarrierOp, call);
1100 StringRef getTypeMangling(
Type type,
bool isSigned) {
1102 .Case<Float16Type>([](
auto) {
return "Dh"; })
1103 .Case<Float32Type>([](
auto) {
return "f"; })
1104 .Case<Float64Type>([](
auto) {
return "d"; })
1105 .Case<IntegerType>([isSigned](IntegerType intTy) {
1106 switch (intTy.getWidth()) {
1110 return (isSigned) ?
"a" :
"c";
1112 return (isSigned) ?
"s" :
"t";
1114 return (isSigned) ?
"i" :
"j";
1116 return (isSigned) ?
"l" :
"m";
1118 llvm_unreachable(
"Unsupported integer width");
1122 llvm_unreachable(
"No mangling defined");
1127 template <
typename ReduceOp>
1128 constexpr StringLiteral getGroupFuncName();
1131 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1132 return "_Z17__spirv_GroupIAddii";
1135 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1136 return "_Z17__spirv_GroupFAddii";
1139 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1140 return "_Z17__spirv_GroupSMinii";
1143 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1144 return "_Z17__spirv_GroupUMinii";
1147 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1148 return "_Z17__spirv_GroupFMinii";
1151 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1152 return "_Z17__spirv_GroupSMaxii";
1155 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1156 return "_Z17__spirv_GroupUMaxii";
1159 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1160 return "_Z17__spirv_GroupFMaxii";
1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1164 return "_Z27__spirv_GroupNonUniformIAddii";
1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1168 return "_Z27__spirv_GroupNonUniformFAddii";
1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1172 return "_Z27__spirv_GroupNonUniformIMulii";
1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1176 return "_Z27__spirv_GroupNonUniformFMulii";
1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1180 return "_Z27__spirv_GroupNonUniformSMinii";
1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1184 return "_Z27__spirv_GroupNonUniformUMinii";
1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1188 return "_Z27__spirv_GroupNonUniformFMinii";
1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1192 return "_Z27__spirv_GroupNonUniformSMaxii";
1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1196 return "_Z27__spirv_GroupNonUniformUMaxii";
1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1200 return "_Z27__spirv_GroupNonUniformFMaxii";
1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1204 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1208 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1212 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1216 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1220 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1223 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1224 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1228 template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1234 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1237 Type retTy = op.getResult().getType();
1242 funcName += getTypeMangling(retTy,
false);
1246 if constexpr (NonUniform) {
1247 if (adaptor.getClusterSize()) {
1249 paramTypes.push_back(i32Ty);
1254 op->template getParentWithTrait<OpTrait::SymbolTable>();
1256 LLVM::LLVMFuncOp func =
1260 Value scope = LLVM::ConstantOp::create(
1261 rewriter, loc, i32Ty,
1262 static_cast<int32_t
>(adaptor.getExecutionScope()));
1263 Value groupOp = LLVM::ConstantOp::create(
1264 rewriter, loc, i32Ty,
1265 static_cast<int32_t
>(adaptor.getGroupOperation()));
1267 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1277 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1278 return "_Z22__spirv_ControlBarrieriii";
1283 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1284 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1289 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1290 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1346 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1353 if (loopOp.getBody().empty()) {
1368 Block *entryBlock = loopOp.getEntryBlock();
1370 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1373 Block *headerBlock = loopOp.getHeaderBlock();
1375 LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
1379 Block *mergeBlock = loopOp.getMergeBlock();
1383 LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
1399 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1411 if (op.getBody().getBlocks().size() <= 2) {
1423 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1429 auto *headerBlock = op.getHeaderBlock();
1431 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1437 auto *mergeBlock = op.getMergeBlock();
1441 LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
1444 Block *trueBlock = condBrOp.getTrueBlock();
1445 Block *falseBlock = condBrOp.getFalseBlock();
1447 LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
1448 condBrOp.getTrueTargetOperands(), falseBlock,
1449 condBrOp.getFalseTargetOperands());
1453 rewriter.
replaceOp(op, continueBlock->getArguments());
1462 template <
typename SPIRVOp,
typename LLVMOp>
1468 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1471 auto dstType = this->getTypeConverter()->convertType(op.getType());
1475 Type op1Type = op.getOperand1().getType();
1476 Type op2Type = op.getOperand2().getType();
1478 if (op1Type == op2Type) {
1479 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1480 adaptor.getOperands());
1484 std::optional<uint64_t> dstTypeWidth =
1486 std::optional<uint64_t> op2TypeWidth =
1489 if (!dstTypeWidth || !op2TypeWidth)
1494 if (op2TypeWidth < dstTypeWidth) {
1497 LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1500 LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1502 }
else if (op2TypeWidth == dstTypeWidth) {
1503 extended = adaptor.getOperand2();
1509 LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
1520 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1522 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1527 Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
1528 Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
1545 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1547 auto srcType = tanhOp.getType();
1548 auto dstType = getTypeConverter()->convertType(srcType);
1555 LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
1556 Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
1559 LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
1561 LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
1573 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1575 auto srcType = varOp.getType();
1577 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1578 auto init = varOp.getInitializer();
1579 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1582 auto dstType = getTypeConverter()->convertType(srcType);
1589 auto elementType = getTypeConverter()->convertType(pointerTo);
1596 auto elementType = getTypeConverter()->convertType(pointerTo);
1600 LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
1601 LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
1611 class BitcastConversionPattern
1617 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1619 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1624 if (isa<LLVM::LLVMPointerType>(dstType)) {
1625 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1630 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1644 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1649 auto funcType = funcOp.getFunctionType();
1651 funcType.getNumInputs());
1653 ->convertFunctionSignature(
1655 false, signatureConverter);
1661 StringRef name = funcOp.getName();
1662 auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
1666 switch (funcOp.getFunctionControl()) {
1667 case spirv::FunctionControl::Inline:
1668 newFuncOp.setAlwaysInline(
true);
1670 case spirv::FunctionControl::DontInline:
1671 newFuncOp.setNoInline(
true);
1674 #define DISPATCH(functionControl, llvmAttr) \
1675 case functionControl: \
1676 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1679 DISPATCH(spirv::FunctionControl::Pure,
1681 DISPATCH(spirv::FunctionControl::Const,
1695 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1712 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1716 ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
1720 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1721 rewriter.
eraseOp(spvModuleOp);
1730 class VectorShufflePattern
1735 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1738 auto components = adaptor.getComponents();
1739 auto vector1 = adaptor.getVector1();
1740 auto vector2 = adaptor.getVector2();
1741 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1742 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1743 if (vector1Size == vector2Size) {
1745 op, vector1, vector2,
1746 LLVM::convertArrayToIndices<int32_t>(components));
1750 auto dstType = getTypeConverter()->convertType(op.getType());
1753 auto scalarType = cast<VectorType>(dstType).getElementType();
1754 auto componentsArray = components.getValue();
1757 Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
1758 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1759 if (!isa<IntegerAttr>(componentsArray[i]))
1760 return op.emitError(
"unable to support non-constant component");
1762 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1767 Value baseVector = vector1;
1768 if (indexVal >= vector1Size) {
1769 offsetVal = vector1Size;
1770 baseVector = vector2;
1773 Value dstIndex = LLVM::ConstantOp::create(
1774 rewriter, loc, llvmI32Type,
1776 Value index = LLVM::ConstantOp::create(
1777 rewriter, loc, llvmI32Type,
1780 auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
1782 targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
1783 extractOp, dstIndex);
1796 spirv::ClientAPI clientAPI) {
1813 spirv::ClientAPI clientAPI) {
1816 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1817 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1818 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1819 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1820 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1821 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1822 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1823 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1824 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1825 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1826 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1827 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1828 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1831 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1832 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1833 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1834 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1835 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1836 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1837 NotPattern<spirv::NotOp>,
1840 BitcastConversionPattern,
1841 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1842 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1843 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1844 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1845 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1846 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1847 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1850 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1851 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1852 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1853 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1854 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1855 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1856 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1857 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1858 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1859 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1860 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1861 LLVM::FCmpPredicate::uge>,
1862 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1863 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1864 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1865 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1866 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1867 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1868 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1869 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1870 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1871 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1872 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1875 ConstantScalarAndVectorPattern,
1878 BranchConversionPattern, BranchConditionalConversionPattern,
1879 FunctionCallPattern, LoopPattern, SelectionPattern,
1880 ErasePattern<spirv::MergeOp>,
1883 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1886 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1887 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1888 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1889 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1890 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1891 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1892 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1893 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1894 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1895 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1896 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1897 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1898 InverseSqrtPattern, TanPattern, TanhPattern,
1901 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1902 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1903 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1904 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1905 NotPattern<spirv::LogicalNotOp>,
1908 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1909 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1912 CompositeExtractPattern, CompositeInsertPattern,
1913 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1914 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1915 VectorShufflePattern,
1918 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1919 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1920 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1923 ReturnPattern, ReturnValuePattern,
1926 ControlBarrierPattern<spirv::ControlBarrierOp>,
1927 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1928 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1931 GroupReducePattern<spirv::GroupIAddOp>,
1932 GroupReducePattern<spirv::GroupFAddOp>,
1933 GroupReducePattern<spirv::GroupFMinOp>,
1934 GroupReducePattern<spirv::GroupUMinOp>,
1935 GroupReducePattern<spirv::GroupSMinOp,
true>,
1936 GroupReducePattern<spirv::GroupFMaxOp>,
1937 GroupReducePattern<spirv::GroupUMaxOp>,
1938 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1939 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1941 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1943 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1945 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1947 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1949 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1951 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1953 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1955 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1957 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1959 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1961 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1963 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1965 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1967 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1969 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1984 patterns.add<ModuleConversionPattern>(
patterns.getContext(), typeConverter);
1995 auto spvModules = module.getOps<spirv::ModuleOp>();
1996 for (
auto spvModule : spvModules) {
1997 spvModule.walk([&](spirv::GlobalVariableOp op) {
1998 IntegerAttr descriptorSet =
2000 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
2003 if (descriptorSet && binding) {
2006 auto moduleAndName =
2007 spvModule.getName().has_value()
2008 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
2009 : op.getSymName().str();
2011 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
2012 std::to_string(descriptorSet.getInt()),
2013 std::to_string(binding.getInt()));
2019 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getBlock() const
Returns the current block of the builder.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.