25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
29 #define DEBUG_TYPE "spirv-to-llvm-pattern"
41 if (
auto vecType = dyn_cast<VectorType>(type))
42 return vecType.getElementType().isSignedInteger();
50 if (
auto vecType = dyn_cast<VectorType>(type))
51 return vecType.getElementType().isUnsignedInteger();
58 if (
auto intType = dyn_cast<IntegerType>(type))
59 return intType.getWidth();
60 if (
auto vecType = dyn_cast<VectorType>(type))
61 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
62 return intType.getWidth();
69 "bitwidth is not supported for this type");
72 auto vecType = dyn_cast<VectorType>(type);
73 auto elementType = vecType.getElementType();
74 assert(elementType.isIntOrFloat() &&
75 "only integers and floats have a bitwidth");
76 return elementType.getIntOrFloatBitWidth();
81 if (
auto vecTy = dyn_cast<VectorType>(type))
82 type = vecTy.getElementType();
83 return cast<IntegerType>(type).getWidth();
88 if (
auto vecType = dyn_cast<VectorType>(type)) {
89 auto integerType = cast<IntegerType>(vecType.getElementType());
92 auto integerType = cast<IntegerType>(type);
99 if (isa<VectorType>(srcType)) {
100 return rewriter.
create<LLVM::ConstantOp>(
105 return rewriter.
create<LLVM::ConstantOp>(
112 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
113 auto floatType = cast<FloatType>(vecType.getElementType());
114 return rewriter.
create<LLVM::ConstantOp>(
119 auto floatType = cast<FloatType>(srcType);
120 return rewriter.
create<LLVM::ConstantOp>(
133 auto srcType = value.
getType();
139 if (valueBitWidth < targetBitWidth)
140 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
145 if (valueBitWidth > targetBitWidth)
146 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
155 auto llvmVectorType = typeConverter.
convertType(vectorType);
157 Value broadcasted = rewriter.
create<LLVM::PoisonOp>(loc, llvmVectorType);
158 for (
unsigned i = 0; i < numElements; ++i) {
159 auto index = rewriter.
create<LLVM::ConstantOp>(
161 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
162 loc, llvmVectorType, broadcasted, toBroadcast, index);
171 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
172 unsigned numElements = vectorType.getNumElements();
173 return broadcast(loc, value, numElements, typeConverter, rewriter);
206 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
216 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
223 return rewriter.
create<LLVM::ConstantOp>(
232 unsigned alignment,
bool isVolatile,
233 bool isNonTemporal) {
234 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
235 auto dstType = typeConverter.
convertType(loadOp.getType());
239 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
240 isVolatile, isNonTemporal);
243 auto storeOp = cast<spirv::StoreOp>(op);
244 spirv::StoreOpAdaptor adaptor(operands);
246 adaptor.getPtr(), alignment,
247 isVolatile, isNonTemporal);
262 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
263 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
266 auto llvmElementType = converter.
convertType(elementType);
275 spirv::ClientAPI clientAPI) {
276 unsigned addressSpace =
298 if (!memberDecorations.empty())
316 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
319 getTypeConverter()->convertType(op.getComponentPtr().getType());
323 auto indices = llvm::to_vector<4>(adaptor.getIndices());
324 Type indexType = op.getIndices().front().getType();
325 auto llvmIndexType = getTypeConverter()->convertType(indexType);
329 op.getLoc(), llvmIndexType, rewriter.
getIntegerAttr(indexType, 0));
330 indices.insert(indices.begin(), zero);
332 auto elementType = getTypeConverter()->convertType(
333 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
337 adaptor.getBasePtr(), indices);
347 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
349 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
358 class BitFieldInsertPattern
364 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
366 auto srcType = op.getType();
367 auto dstType = getTypeConverter()->convertType(srcType);
374 *getTypeConverter(), rewriter);
376 *getTypeConverter(), rewriter);
380 Value maskShiftedByCount =
381 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
382 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
383 maskShiftedByCount, minusOne);
384 Value maskShiftedByCountAndOffset =
385 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
387 loc, dstType, maskShiftedByCountAndOffset, minusOne);
392 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
393 Value insertShiftedByOffset =
394 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
396 insertShiftedByOffset);
402 class ConstantScalarAndVectorPattern
408 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
410 auto srcType = constOp.getType();
411 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
414 auto dstType = getTypeConverter()->convertType(srcType);
427 if (isa<VectorType>(srcType)) {
428 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
431 dstElementsAttr.mapValues(
432 signlessType, [&](
const APInt &value) {
return value; }));
435 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
436 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
441 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
446 class BitFieldSExtractPattern
452 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
454 auto srcType = op.getType();
455 auto dstType = getTypeConverter()->convertType(srcType);
462 *getTypeConverter(), rewriter);
464 *getTypeConverter(), rewriter);
467 IntegerType integerType;
468 if (
auto vecType = dyn_cast<VectorType>(srcType))
469 integerType = cast<IntegerType>(vecType.getElementType());
471 integerType = cast<IntegerType>(srcType);
475 isa<VectorType>(srcType)
476 ? rewriter.
create<LLVM::ConstantOp>(
479 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
483 Value countPlusOffset =
484 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
485 Value amountToShiftLeft =
486 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
487 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
488 loc, dstType, op.getBase(), amountToShiftLeft);
491 Value amountToShiftRight =
492 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
499 class BitFieldUExtractPattern
505 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
507 auto srcType = op.getType();
508 auto dstType = getTypeConverter()->convertType(srcType);
515 *getTypeConverter(), rewriter);
517 *getTypeConverter(), rewriter);
521 Value maskShiftedByCount =
522 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
523 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
528 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
539 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
542 branchOp.getTarget());
547 class BranchConditionalConversionPattern
554 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
558 if (
auto weights = op.getBranchWeights()) {
560 for (
auto weight : weights->getAsRange<IntegerAttr>())
561 weightValues.push_back(weight.getInt());
566 op, op.getCondition(), op.getTrueBlockArguments(),
567 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
576 class CompositeExtractPattern
582 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
584 auto dstType = this->getTypeConverter()->convertType(op.getType());
588 Type containerType = op.getComposite().getType();
589 if (isa<VectorType>(containerType)) {
591 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
594 op, dstType, adaptor.getComposite(), index);
599 op, adaptor.getComposite(),
608 class CompositeInsertPattern
614 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
616 auto dstType = this->getTypeConverter()->convertType(op.getType());
620 Type containerType = op.getComposite().getType();
621 if (isa<VectorType>(containerType)) {
623 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
626 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
631 op, adaptor.getComposite(), adaptor.getObject(),
639 template <
typename SPIRVOp,
typename LLVMOp>
645 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
647 auto dstType = this->getTypeConverter()->convertType(op.getType());
650 rewriter.template replaceOpWithNewOp<LLVMOp>(
651 op, dstType, adaptor.getOperands(), op->getAttrs());
658 class ExecutionModePattern
664 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
669 ModuleOp module = op->getParentOfType<ModuleOp>();
670 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
671 std::string moduleName;
672 if (module.getName().has_value())
673 moduleName =
"_" + module.getName()->str();
676 std::string executionModeInfoName = llvm::formatv(
677 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
678 static_cast<uint32_t
>(executionModeAttr.getValue()));
691 fields.push_back(llvmI32Type);
692 ArrayAttr values = op.getValues();
693 if (!values.empty()) {
695 fields.push_back(arrayType);
697 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
700 auto global = rewriter.
create<LLVM::GlobalOp>(
702 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
705 Region ®ion = global.getInitializerRegion();
710 Value structValue = rewriter.
create<LLVM::PoisonOp>(loc, structType);
711 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
714 static_cast<uint32_t
>(executionModeAttr.getValue())));
715 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
719 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
720 auto attr = values.getValue()[i];
721 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
722 structValue = rewriter.
create<LLVM::InsertValueOp>(
735 class GlobalVariablePattern
738 template <
typename... Args>
739 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
741 std::forward<Args>(args)...),
742 clientAPI(clientAPI) {}
745 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
749 if (op.getInitializer())
752 auto srcType = cast<spirv::PointerType>(op.getType());
753 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
760 auto storageClass = srcType.getStorageClass();
761 switch (storageClass) {
762 case spirv::StorageClass::Input:
763 case spirv::StorageClass::Private:
764 case spirv::StorageClass::Output:
765 case spirv::StorageClass::StorageBuffer:
766 case spirv::StorageClass::UniformConstant:
775 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
776 (storageClass == spirv::StorageClass::UniformConstant);
782 auto linkage = storageClass == spirv::StorageClass::Private
783 ? LLVM::Linkage::Private
784 : LLVM::Linkage::External;
786 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
790 if (op.getLocationAttr())
791 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
797 spirv::ClientAPI clientAPI;
802 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
808 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
811 Type fromType = op.getOperand().getType();
812 Type toType = op.getType();
814 auto dstType = this->getTypeConverter()->convertType(toType);
819 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
820 adaptor.getOperands());
824 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
825 adaptor.getOperands());
832 class FunctionCallPattern
838 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840 if (callOp.getNumResults() == 0) {
842 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
843 newOp.getProperties().operandSegmentSizes = {
844 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
850 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
854 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
855 newOp.getProperties().operandSegmentSizes = {
856 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
863 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
869 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
872 auto dstType = this->getTypeConverter()->convertType(op.getType());
876 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
877 op, dstType, predicate, op.getOperand1(), op.getOperand2());
883 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
889 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
892 auto dstType = this->getTypeConverter()->convertType(op.getType());
896 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
897 op, dstType, predicate, op.getOperand1(), op.getOperand2());
902 class InverseSqrtPattern
908 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910 auto srcType = op.getType();
911 auto dstType = getTypeConverter()->convertType(srcType);
917 Value sqrt = rewriter.
create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
924 template <
typename SPIRVOp>
930 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
932 if (!op.getMemoryAccess()) {
934 *this->getTypeConverter(), 0,
938 auto memoryAccess = *op.getMemoryAccess();
939 switch (memoryAccess) {
940 case spirv::MemoryAccess::Aligned:
942 case spirv::MemoryAccess::Nontemporal:
943 case spirv::MemoryAccess::Volatile: {
945 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
946 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
947 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949 *this->getTypeConverter(), alignment,
950 isVolatile, isNonTemporal);
960 template <
typename SPIRVOp>
966 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
968 auto srcType = notOp.getType();
969 auto dstType = this->getTypeConverter()->convertType(srcType);
976 isa<VectorType>(srcType)
977 ? rewriter.
create<LLVM::ConstantOp>(
980 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
981 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
982 notOp.getOperand(), mask);
988 template <
typename SPIRVOp>
994 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1006 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1019 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1022 adaptor.getOperands());
1031 bool convergent =
true) {
1032 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1038 func = b.create<LLVM::LLVMFuncOp>(
1039 symbolTable->
getLoc(), name,
1041 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1042 func.setConvergent(convergent);
1043 func.setNoUnwind(
true);
1044 func.setWillReturn(
true);
1049 LLVM::LLVMFuncOp func,
1051 auto call = builder.
create<LLVM::CallOp>(loc, func, args);
1052 call.setCConv(func.getCConv());
1053 call.setConvergentAttr(func.getConvergentAttr());
1054 call.setNoUnwindAttr(func.getNoUnwindAttr());
1055 call.setWillReturnAttr(func.getWillReturnAttr());
1059 template <
typename BarrierOpTy>
1066 static constexpr StringRef getFuncName();
1069 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071 constexpr StringRef funcName = getFuncName();
1073 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1077 Type voidTy = rewriter.
getType<LLVM::LLVMVoidType>();
1078 LLVM::LLVMFuncOp func =
1081 Location loc = controlBarrierOp->getLoc();
1082 Value execution = rewriter.
create<LLVM::ConstantOp>(
1083 loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1085 loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1086 Value semantics = rewriter.
create<LLVM::ConstantOp>(
1087 loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1090 {execution, memory, semantics});
1092 rewriter.
replaceOp(controlBarrierOp, call);
1099 StringRef getTypeMangling(
Type type,
bool isSigned) {
1101 .Case<Float16Type>([](
auto) {
return "Dh"; })
1102 .Case<Float32Type>([](
auto) {
return "f"; })
1103 .Case<Float64Type>([](
auto) {
return "d"; })
1104 .Case<IntegerType>([isSigned](IntegerType intTy) {
1105 switch (intTy.getWidth()) {
1109 return (isSigned) ?
"a" :
"c";
1111 return (isSigned) ?
"s" :
"t";
1113 return (isSigned) ?
"i" :
"j";
1115 return (isSigned) ?
"l" :
"m";
1117 llvm_unreachable(
"Unsupported integer width");
1121 llvm_unreachable(
"No mangling defined");
1126 template <
typename ReduceOp>
1127 constexpr StringLiteral getGroupFuncName();
1130 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1131 return "_Z17__spirv_GroupIAddii";
1134 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1135 return "_Z17__spirv_GroupFAddii";
1138 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1139 return "_Z17__spirv_GroupSMinii";
1142 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1143 return "_Z17__spirv_GroupUMinii";
1146 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1147 return "_Z17__spirv_GroupFMinii";
1150 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1151 return "_Z17__spirv_GroupSMaxii";
1154 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1155 return "_Z17__spirv_GroupUMaxii";
1158 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1159 return "_Z17__spirv_GroupFMaxii";
1162 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1163 return "_Z27__spirv_GroupNonUniformIAddii";
1166 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1167 return "_Z27__spirv_GroupNonUniformFAddii";
1170 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1171 return "_Z27__spirv_GroupNonUniformIMulii";
1174 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1175 return "_Z27__spirv_GroupNonUniformFMulii";
1178 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1179 return "_Z27__spirv_GroupNonUniformSMinii";
1182 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1183 return "_Z27__spirv_GroupNonUniformUMinii";
1186 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1187 return "_Z27__spirv_GroupNonUniformFMinii";
1190 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1191 return "_Z27__spirv_GroupNonUniformSMaxii";
1194 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1195 return "_Z27__spirv_GroupNonUniformUMaxii";
1198 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1199 return "_Z27__spirv_GroupNonUniformFMaxii";
1202 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1203 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1206 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1207 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1210 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1211 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1214 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1215 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1218 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1219 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1222 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1223 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1227 template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1233 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1236 Type retTy = op.getResult().getType();
1241 funcName += getTypeMangling(retTy,
false);
1245 if constexpr (NonUniform) {
1246 if (adaptor.getClusterSize()) {
1248 paramTypes.push_back(i32Ty);
1253 op->template getParentWithTrait<OpTrait::SymbolTable>();
1255 LLVM::LLVMFuncOp func =
1260 loc, i32Ty,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1261 Value groupOp = rewriter.
create<LLVM::ConstantOp>(
1262 loc, i32Ty,
static_cast<int32_t
>(adaptor.getGroupOperation()));
1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1274 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1275 return "_Z22__spirv_ControlBarrieriii";
1280 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1281 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1286 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1287 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1350 if (loopOp.getBody().empty()) {
1365 Block *entryBlock = loopOp.getEntryBlock();
1367 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1370 Block *headerBlock = loopOp.getHeaderBlock();
1372 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1376 Block *mergeBlock = loopOp.getMergeBlock();
1380 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1408 if (op.getBody().getBlocks().size() <= 2) {
1420 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1426 auto *headerBlock = op.getHeaderBlock();
1428 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1435 auto *mergeBlock = op.getMergeBlock();
1439 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1442 Block *trueBlock = condBrOp.getTrueBlock();
1443 Block *falseBlock = condBrOp.getFalseBlock();
1445 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1446 condBrOp.getTrueTargetOperands(),
1448 condBrOp.getFalseTargetOperands());
1451 rewriter.
replaceOp(op, continueBlock->getArguments());
1460 template <
typename SPIRVOp,
typename LLVMOp>
1466 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1469 auto dstType = this->getTypeConverter()->convertType(op.getType());
1473 Type op1Type = op.getOperand1().getType();
1474 Type op2Type = op.getOperand2().getType();
1476 if (op1Type == op2Type) {
1477 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1478 adaptor.getOperands());
1482 std::optional<uint64_t> dstTypeWidth =
1484 std::optional<uint64_t> op2TypeWidth =
1487 if (!dstTypeWidth || !op2TypeWidth)
1492 if (op2TypeWidth < dstTypeWidth) {
1494 extended = rewriter.template create<LLVM::ZExtOp>(
1495 loc, dstType, adaptor.getOperand2());
1497 extended = rewriter.template create<LLVM::SExtOp>(
1498 loc, dstType, adaptor.getOperand2());
1500 }
else if (op2TypeWidth == dstTypeWidth) {
1501 extended = adaptor.getOperand2();
1506 Value result = rewriter.template create<LLVMOp>(
1507 loc, dstType, adaptor.getOperand1(), extended);
1518 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1520 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1525 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1526 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1543 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1545 auto srcType = tanhOp.getType();
1546 auto dstType = getTypeConverter()->convertType(srcType);
1553 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1554 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1557 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1559 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1571 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1573 auto srcType = varOp.getType();
1575 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1576 auto init = varOp.getInitializer();
1577 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1580 auto dstType = getTypeConverter()->convertType(srcType);
1587 auto elementType = getTypeConverter()->convertType(pointerTo);
1594 auto elementType = getTypeConverter()->convertType(pointerTo);
1598 rewriter.
create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1599 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1609 class BitcastConversionPattern
1615 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1617 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1622 if (isa<LLVM::LLVMPointerType>(dstType)) {
1623 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1628 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1642 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1647 auto funcType = funcOp.getFunctionType();
1649 funcType.getNumInputs());
1651 ->convertFunctionSignature(
1653 false, signatureConverter);
1659 StringRef name = funcOp.getName();
1660 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1664 switch (funcOp.getFunctionControl()) {
1665 case spirv::FunctionControl::Inline:
1666 newFuncOp.setAlwaysInline(
true);
1668 case spirv::FunctionControl::DontInline:
1669 newFuncOp.setNoInline(
true);
1672 #define DISPATCH(functionControl, llvmAttr) \
1673 case functionControl: \
1674 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1677 DISPATCH(spirv::FunctionControl::Pure,
1679 DISPATCH(spirv::FunctionControl::Const,
1693 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1710 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1714 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1718 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1719 rewriter.
eraseOp(spvModuleOp);
1728 class VectorShufflePattern
1733 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1736 auto components = adaptor.getComponents();
1737 auto vector1 = adaptor.getVector1();
1738 auto vector2 = adaptor.getVector2();
1739 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1740 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1741 if (vector1Size == vector2Size) {
1743 op, vector1, vector2,
1744 LLVM::convertArrayToIndices<int32_t>(components));
1748 auto dstType = getTypeConverter()->convertType(op.getType());
1751 auto scalarType = cast<VectorType>(dstType).getElementType();
1752 auto componentsArray = components.getValue();
1755 Value targetOp = rewriter.
create<LLVM::PoisonOp>(loc, dstType);
1756 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1757 if (!isa<IntegerAttr>(componentsArray[i]))
1758 return op.
emitError(
"unable to support non-constant component");
1760 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1765 Value baseVector = vector1;
1766 if (indexVal >= vector1Size) {
1767 offsetVal = vector1Size;
1768 baseVector = vector2;
1771 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1777 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1778 loc, scalarType, baseVector, index);
1779 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1780 extractOp, dstIndex);
1793 spirv::ClientAPI clientAPI) {
1810 spirv::ClientAPI clientAPI) {
1813 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1814 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1815 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1816 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1817 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1818 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1819 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1820 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1821 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1822 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1823 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1824 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1825 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1828 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1829 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1830 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1831 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1832 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1833 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1834 NotPattern<spirv::NotOp>,
1837 BitcastConversionPattern,
1838 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1839 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1840 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1841 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1842 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1843 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1844 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1847 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1848 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1849 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1850 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1851 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1852 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1853 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1854 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1855 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1856 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1857 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1858 LLVM::FCmpPredicate::uge>,
1859 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1860 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1861 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1862 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1863 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1864 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1865 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1866 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1867 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1868 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1869 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1872 ConstantScalarAndVectorPattern,
1875 BranchConversionPattern, BranchConditionalConversionPattern,
1876 FunctionCallPattern, LoopPattern, SelectionPattern,
1877 ErasePattern<spirv::MergeOp>,
1880 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1883 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1884 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1885 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1886 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1887 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1888 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1889 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1890 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1891 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1892 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1893 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1894 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1895 InverseSqrtPattern, TanPattern, TanhPattern,
1898 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1899 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1900 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1901 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1902 NotPattern<spirv::LogicalNotOp>,
1905 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1906 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1909 CompositeExtractPattern, CompositeInsertPattern,
1910 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1911 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1912 VectorShufflePattern,
1915 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1916 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1917 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1920 ReturnPattern, ReturnValuePattern,
1923 ControlBarrierPattern<spirv::ControlBarrierOp>,
1924 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1925 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1928 GroupReducePattern<spirv::GroupIAddOp>,
1929 GroupReducePattern<spirv::GroupFAddOp>,
1930 GroupReducePattern<spirv::GroupFMinOp>,
1931 GroupReducePattern<spirv::GroupUMinOp>,
1932 GroupReducePattern<spirv::GroupSMinOp,
true>,
1933 GroupReducePattern<spirv::GroupFMaxOp>,
1934 GroupReducePattern<spirv::GroupUMaxOp>,
1935 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1936 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1938 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1940 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1942 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1944 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1946 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1948 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1950 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1952 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1954 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1956 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1958 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1960 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1962 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1964 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1966 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1981 patterns.add<ModuleConversionPattern>(
patterns.getContext(), typeConverter);
1992 auto spvModules = module.getOps<spirv::ModuleOp>();
1993 for (
auto spvModule : spvModules) {
1994 spvModule.walk([&](spirv::GlobalVariableOp op) {
1995 IntegerAttr descriptorSet =
1997 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
2000 if (descriptorSet && binding) {
2003 auto moduleAndName =
2004 spvModule.getName().has_value()
2005 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
2006 : op.getSymName().str();
2008 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
2009 std::to_string(descriptorSet.getInt()),
2010 std::to_string(binding.getInt()));
2016 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.