25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FormatVariadic.h"
29 #define DEBUG_TYPE "spirv-to-llvm-pattern"
41 if (
auto vecType = dyn_cast<VectorType>(type))
42 return vecType.getElementType().isSignedInteger();
50 if (
auto vecType = dyn_cast<VectorType>(type))
51 return vecType.getElementType().isUnsignedInteger();
58 if (
auto intType = dyn_cast<IntegerType>(type))
59 return intType.getWidth();
60 if (
auto vecType = dyn_cast<VectorType>(type))
61 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
62 return intType.getWidth();
69 "bitwidth is not supported for this type");
72 auto vecType = dyn_cast<VectorType>(type);
73 auto elementType = vecType.getElementType();
74 assert(elementType.isIntOrFloat() &&
75 "only integers and floats have a bitwidth");
76 return elementType.getIntOrFloatBitWidth();
89 if (
auto vecType = dyn_cast<VectorType>(type)) {
90 auto integerType = cast<IntegerType>(vecType.getElementType());
93 auto integerType = cast<IntegerType>(type);
100 if (isa<VectorType>(srcType)) {
101 return rewriter.
create<LLVM::ConstantOp>(
106 return rewriter.
create<LLVM::ConstantOp>(
113 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
114 auto floatType = cast<FloatType>(vecType.getElementType());
115 return rewriter.
create<LLVM::ConstantOp>(
120 auto floatType = cast<FloatType>(srcType);
121 return rewriter.
create<LLVM::ConstantOp>(
134 auto srcType = value.
getType();
140 if (valueBitWidth < targetBitWidth)
141 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
146 if (valueBitWidth > targetBitWidth)
147 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
156 auto llvmVectorType = typeConverter.
convertType(vectorType);
158 Value broadcasted = rewriter.
create<LLVM::UndefOp>(loc, llvmVectorType);
159 for (
unsigned i = 0; i < numElements; ++i) {
160 auto index = rewriter.
create<LLVM::ConstantOp>(
162 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
163 loc, llvmVectorType, broadcasted, toBroadcast, index);
172 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
173 unsigned numElements = vectorType.getNumElements();
174 return broadcast(loc, value, numElements, typeConverter, rewriter);
224 return rewriter.
create<LLVM::ConstantOp>(
233 unsigned alignment,
bool isVolatile,
234 bool isNonTemporal) {
235 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
236 auto dstType = typeConverter.
convertType(loadOp.getType());
240 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
241 isVolatile, isNonTemporal);
244 auto storeOp = cast<spirv::StoreOp>(op);
245 spirv::StoreOpAdaptor adaptor(operands);
247 adaptor.getPtr(), alignment,
248 isVolatile, isNonTemporal);
263 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
264 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
267 auto llvmElementType = converter.
convertType(elementType);
276 spirv::ClientAPI clientAPI) {
277 unsigned addressSpace =
299 if (!memberDecorations.empty())
317 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
320 getTypeConverter()->convertType(op.getComponentPtr().getType());
324 auto indices = llvm::to_vector<4>(adaptor.getIndices());
325 Type indexType = op.getIndices().front().getType();
326 auto llvmIndexType = getTypeConverter()->convertType(indexType);
330 op.getLoc(), llvmIndexType, rewriter.
getIntegerAttr(indexType, 0));
331 indices.insert(indices.begin(), zero);
333 auto elementType = getTypeConverter()->convertType(
334 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
338 adaptor.getBasePtr(), indices);
348 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
350 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
359 class BitFieldInsertPattern
365 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
367 auto srcType = op.getType();
368 auto dstType = getTypeConverter()->convertType(srcType);
375 *getTypeConverter(), rewriter);
377 *getTypeConverter(), rewriter);
381 Value maskShiftedByCount =
382 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
383 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
384 maskShiftedByCount, minusOne);
385 Value maskShiftedByCountAndOffset =
386 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
388 loc, dstType, maskShiftedByCountAndOffset, minusOne);
393 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
394 Value insertShiftedByOffset =
395 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
397 insertShiftedByOffset);
403 class ConstantScalarAndVectorPattern
409 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
411 auto srcType = constOp.getType();
412 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
415 auto dstType = getTypeConverter()->convertType(srcType);
428 if (isa<VectorType>(srcType)) {
429 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
432 dstElementsAttr.mapValues(
433 signlessType, [&](
const APInt &value) {
return value; }));
436 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
437 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
442 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
447 class BitFieldSExtractPattern
453 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
455 auto srcType = op.getType();
456 auto dstType = getTypeConverter()->convertType(srcType);
463 *getTypeConverter(), rewriter);
465 *getTypeConverter(), rewriter);
468 IntegerType integerType;
469 if (
auto vecType = dyn_cast<VectorType>(srcType))
470 integerType = cast<IntegerType>(vecType.getElementType());
472 integerType = cast<IntegerType>(srcType);
476 isa<VectorType>(srcType)
477 ? rewriter.
create<LLVM::ConstantOp>(
480 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
484 Value countPlusOffset =
485 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
486 Value amountToShiftLeft =
487 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
488 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
489 loc, dstType, op.getBase(), amountToShiftLeft);
492 Value amountToShiftRight =
493 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
500 class BitFieldUExtractPattern
506 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
508 auto srcType = op.getType();
509 auto dstType = getTypeConverter()->convertType(srcType);
516 *getTypeConverter(), rewriter);
518 *getTypeConverter(), rewriter);
522 Value maskShiftedByCount =
523 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
524 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
529 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
540 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
543 branchOp.getTarget());
548 class BranchConditionalConversionPattern
555 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
559 if (
auto weights = op.getBranchWeights()) {
561 for (
auto weight : weights->getAsRange<IntegerAttr>())
562 weightValues.push_back(weight.getInt());
567 op, op.getCondition(), op.getTrueBlockArguments(),
568 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
577 class CompositeExtractPattern
583 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
585 auto dstType = this->getTypeConverter()->convertType(op.getType());
589 Type containerType = op.getComposite().getType();
590 if (isa<VectorType>(containerType)) {
592 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
595 op, dstType, adaptor.getComposite(), index);
600 op, adaptor.getComposite(),
609 class CompositeInsertPattern
615 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
617 auto dstType = this->getTypeConverter()->convertType(op.getType());
621 Type containerType = op.getComposite().getType();
622 if (isa<VectorType>(containerType)) {
624 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
627 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
632 op, adaptor.getComposite(), adaptor.getObject(),
640 template <
typename SPIRVOp,
typename LLVMOp>
646 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
648 auto dstType = this->getTypeConverter()->convertType(op.getType());
651 rewriter.template replaceOpWithNewOp<LLVMOp>(
652 op, dstType, adaptor.getOperands(), op->getAttrs());
659 class ExecutionModePattern
665 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
670 ModuleOp module = op->getParentOfType<ModuleOp>();
671 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
672 std::string moduleName;
673 if (module.getName().has_value())
674 moduleName =
"_" + module.getName()->str();
677 std::string executionModeInfoName = llvm::formatv(
678 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
679 static_cast<uint32_t
>(executionModeAttr.getValue()));
692 fields.push_back(llvmI32Type);
693 ArrayAttr values = op.getValues();
694 if (!values.empty()) {
696 fields.push_back(arrayType);
701 auto global = rewriter.
create<LLVM::GlobalOp>(
703 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
706 Region ®ion = global.getInitializerRegion();
711 Value structValue = rewriter.
create<LLVM::UndefOp>(loc, structType);
712 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
715 static_cast<uint32_t
>(executionModeAttr.getValue())));
716 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
720 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
721 auto attr = values.getValue()[i];
722 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
723 structValue = rewriter.
create<LLVM::InsertValueOp>(
736 class GlobalVariablePattern
739 template <
typename... Args>
740 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
742 std::forward<Args>(args)...),
743 clientAPI(clientAPI) {}
746 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
750 if (op.getInitializer())
753 auto srcType = cast<spirv::PointerType>(op.getType());
754 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
761 auto storageClass = srcType.getStorageClass();
762 switch (storageClass) {
763 case spirv::StorageClass::Input:
764 case spirv::StorageClass::Private:
765 case spirv::StorageClass::Output:
766 case spirv::StorageClass::StorageBuffer:
767 case spirv::StorageClass::UniformConstant:
776 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
777 (storageClass == spirv::StorageClass::UniformConstant);
783 auto linkage = storageClass == spirv::StorageClass::Private
784 ? LLVM::Linkage::Private
785 : LLVM::Linkage::External;
787 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
791 if (op.getLocationAttr())
792 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
798 spirv::ClientAPI clientAPI;
803 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
809 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
812 Type fromType = op.getOperand().getType();
813 Type toType = op.getType();
815 auto dstType = this->getTypeConverter()->convertType(toType);
820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821 adaptor.getOperands());
825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826 adaptor.getOperands());
833 class FunctionCallPattern
839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
841 if (callOp.getNumResults() == 0) {
843 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
844 newOp.getProperties().operandSegmentSizes = {
845 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
851 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856 newOp.getProperties().operandSegmentSizes = {
857 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
864 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
870 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
873 auto dstType = this->getTypeConverter()->convertType(op.getType());
877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878 op, dstType, predicate, op.getOperand1(), op.getOperand2());
884 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
890 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
893 auto dstType = this->getTypeConverter()->convertType(op.getType());
897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898 op, dstType, predicate, op.getOperand1(), op.getOperand2());
903 class InverseSqrtPattern
909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
911 auto srcType = op.getType();
912 auto dstType = getTypeConverter()->convertType(srcType);
918 Value sqrt = rewriter.
create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
925 template <
typename SPIRVOp>
931 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
933 if (!op.getMemoryAccess()) {
935 *this->getTypeConverter(), 0,
939 auto memoryAccess = *op.getMemoryAccess();
940 switch (memoryAccess) {
941 case spirv::MemoryAccess::Aligned:
943 case spirv::MemoryAccess::Nontemporal:
944 case spirv::MemoryAccess::Volatile: {
946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
950 *this->getTypeConverter(), alignment,
951 isVolatile, isNonTemporal);
961 template <
typename SPIRVOp>
967 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
969 auto srcType = notOp.getType();
970 auto dstType = this->getTypeConverter()->convertType(srcType);
977 isa<VectorType>(srcType)
978 ? rewriter.
create<LLVM::ConstantOp>(
981 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983 notOp.getOperand(), mask);
989 template <
typename SPIRVOp>
995 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1023 adaptor.getOperands());
1032 bool convergent =
true) {
1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1039 func = b.create<LLVM::LLVMFuncOp>(
1040 symbolTable->
getLoc(), name,
1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043 func.setConvergent(convergent);
1044 func.setNoUnwind(
true);
1045 func.setWillReturn(
true);
1050 LLVM::LLVMFuncOp func,
1052 auto call = builder.
create<LLVM::CallOp>(loc, func, args);
1053 call.setCConv(func.getCConv());
1054 call.setConvergentAttr(func.getConvergentAttr());
1055 call.setNoUnwindAttr(func.getNoUnwindAttr());
1056 call.setWillReturnAttr(func.getWillReturnAttr());
1060 template <
typename BarrierOpTy>
1067 static constexpr StringRef getFuncName();
1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1072 constexpr StringRef funcName = getFuncName();
1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1078 Type voidTy = rewriter.
getType<LLVM::LLVMVoidType>();
1079 LLVM::LLVMFuncOp func =
1082 Location loc = controlBarrierOp->getLoc();
1083 Value execution = rewriter.
create<LLVM::ConstantOp>(
1084 loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1086 loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1087 Value semantics = rewriter.
create<LLVM::ConstantOp>(
1088 loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1091 {execution, memory, semantics});
1093 rewriter.
replaceOp(controlBarrierOp, call);
1100 StringRef getTypeMangling(
Type type,
bool isSigned) {
1102 .Case<Float16Type>([](
auto) {
return "Dh"; })
1103 .Case<Float32Type>([](
auto) {
return "f"; })
1104 .Case<Float64Type>([](
auto) {
return "d"; })
1105 .Case<IntegerType>([isSigned](IntegerType intTy) {
1106 switch (intTy.getWidth()) {
1110 return (isSigned) ?
"a" :
"c";
1112 return (isSigned) ?
"s" :
"t";
1114 return (isSigned) ?
"i" :
"j";
1116 return (isSigned) ?
"l" :
"m";
1118 llvm_unreachable(
"Unsupported integer width");
1122 llvm_unreachable(
"No mangling defined");
1127 template <
typename ReduceOp>
1128 constexpr StringLiteral getGroupFuncName();
1131 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1132 return "_Z17__spirv_GroupIAddii";
1135 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1136 return "_Z17__spirv_GroupFAddii";
1139 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1140 return "_Z17__spirv_GroupSMinii";
1143 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1144 return "_Z17__spirv_GroupUMinii";
1147 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1148 return "_Z17__spirv_GroupFMinii";
1151 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1152 return "_Z17__spirv_GroupSMaxii";
1155 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1156 return "_Z17__spirv_GroupUMaxii";
1159 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1160 return "_Z17__spirv_GroupFMaxii";
1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1164 return "_Z27__spirv_GroupNonUniformIAddii";
1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1168 return "_Z27__spirv_GroupNonUniformFAddii";
1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1172 return "_Z27__spirv_GroupNonUniformIMulii";
1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1176 return "_Z27__spirv_GroupNonUniformFMulii";
1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1180 return "_Z27__spirv_GroupNonUniformSMinii";
1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1184 return "_Z27__spirv_GroupNonUniformUMinii";
1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1188 return "_Z27__spirv_GroupNonUniformFMinii";
1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1192 return "_Z27__spirv_GroupNonUniformSMaxii";
1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1196 return "_Z27__spirv_GroupNonUniformUMaxii";
1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1200 return "_Z27__spirv_GroupNonUniformFMaxii";
1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1204 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1208 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1212 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1216 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1220 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1223 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1224 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1228 template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1234 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1237 Type retTy = op.getResult().getType();
1242 funcName += getTypeMangling(retTy,
false);
1246 if constexpr (NonUniform) {
1247 if (adaptor.getClusterSize()) {
1249 paramTypes.push_back(i32Ty);
1254 op->template getParentWithTrait<OpTrait::SymbolTable>();
1257 symbolTable, funcName, paramTypes, retTy, !NonUniform);
1261 loc, i32Ty,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1262 Value groupOp = rewriter.
create<LLVM::ConstantOp>(
1263 loc, i32Ty,
static_cast<int32_t
>(adaptor.getGroupOperation()));
1265 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1275 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1276 return "_Z22__spirv_ControlBarrieriii";
1281 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1282 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1287 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1288 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1344 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1351 if (loopOp.getBody().empty()) {
1366 Block *entryBlock = loopOp.getEntryBlock();
1368 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1371 Block *headerBlock = loopOp.getHeaderBlock();
1373 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1377 Block *mergeBlock = loopOp.getMergeBlock();
1381 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1397 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1409 if (op.getBody().getBlocks().size() <= 2) {
1421 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1427 auto *headerBlock = op.getHeaderBlock();
1429 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1436 auto *mergeBlock = op.getMergeBlock();
1440 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1443 Block *trueBlock = condBrOp.getTrueBlock();
1444 Block *falseBlock = condBrOp.getFalseBlock();
1446 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1447 condBrOp.getTrueTargetOperands(),
1449 condBrOp.getFalseTargetOperands());
1452 rewriter.
replaceOp(op, continueBlock->getArguments());
1461 template <
typename SPIRVOp,
typename LLVMOp>
1467 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1470 auto dstType = this->getTypeConverter()->convertType(op.getType());
1474 Type op1Type = op.getOperand1().getType();
1475 Type op2Type = op.getOperand2().getType();
1477 if (op1Type == op2Type) {
1478 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1479 adaptor.getOperands());
1483 std::optional<uint64_t> dstTypeWidth =
1485 std::optional<uint64_t> op2TypeWidth =
1488 if (!dstTypeWidth || !op2TypeWidth)
1493 if (op2TypeWidth < dstTypeWidth) {
1495 extended = rewriter.template create<LLVM::ZExtOp>(
1496 loc, dstType, adaptor.getOperand2());
1498 extended = rewriter.template create<LLVM::SExtOp>(
1499 loc, dstType, adaptor.getOperand2());
1501 }
else if (op2TypeWidth == dstTypeWidth) {
1502 extended = adaptor.getOperand2();
1507 Value result = rewriter.template create<LLVMOp>(
1508 loc, dstType, adaptor.getOperand1(), extended);
1519 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1521 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1526 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1527 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1544 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1546 auto srcType = tanhOp.getType();
1547 auto dstType = getTypeConverter()->convertType(srcType);
1554 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1555 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1558 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1560 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1572 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1574 auto srcType = varOp.getType();
1576 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1577 auto init = varOp.getInitializer();
1578 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1581 auto dstType = getTypeConverter()->convertType(srcType);
1588 auto elementType = getTypeConverter()->convertType(pointerTo);
1595 auto elementType = getTypeConverter()->convertType(pointerTo);
1599 rewriter.
create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1600 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1610 class BitcastConversionPattern
1616 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1618 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1623 if (isa<LLVM::LLVMPointerType>(dstType)) {
1624 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1629 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1643 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1648 auto funcType = funcOp.getFunctionType();
1650 funcType.getNumInputs());
1652 ->convertFunctionSignature(
1654 false, signatureConverter);
1660 StringRef name = funcOp.getName();
1661 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1665 switch (funcOp.getFunctionControl()) {
1666 case spirv::FunctionControl::Inline:
1667 newFuncOp.setAlwaysInline(
true);
1669 case spirv::FunctionControl::DontInline:
1670 newFuncOp.setNoInline(
true);
1673 #define DISPATCH(functionControl, llvmAttr) \
1674 case functionControl: \
1675 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1678 DISPATCH(spirv::FunctionControl::Pure,
1680 DISPATCH(spirv::FunctionControl::Const,
1694 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1711 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1715 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1719 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1720 rewriter.
eraseOp(spvModuleOp);
1729 class VectorShufflePattern
1734 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1737 auto components = adaptor.getComponents();
1738 auto vector1 = adaptor.getVector1();
1739 auto vector2 = adaptor.getVector2();
1740 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1741 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1742 if (vector1Size == vector2Size) {
1744 op, vector1, vector2,
1745 LLVM::convertArrayToIndices<int32_t>(components));
1749 auto dstType = getTypeConverter()->convertType(op.getType());
1752 auto scalarType = cast<VectorType>(dstType).getElementType();
1753 auto componentsArray = components.getValue();
1756 Value targetOp = rewriter.
create<LLVM::UndefOp>(loc, dstType);
1757 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1758 if (!isa<IntegerAttr>(componentsArray[i]))
1759 return op.
emitError(
"unable to support non-constant component");
1761 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1766 Value baseVector = vector1;
1767 if (indexVal >= vector1Size) {
1768 offsetVal = vector1Size;
1769 baseVector = vector2;
1772 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1778 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1779 loc, scalarType, baseVector, index);
1780 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1781 extractOp, dstIndex);
1794 spirv::ClientAPI clientAPI) {
1811 spirv::ClientAPI clientAPI) {
1814 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1815 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1816 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1817 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1818 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1819 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1820 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1821 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1822 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1823 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1824 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1825 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1826 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1829 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1830 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1831 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1832 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1833 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1834 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1835 NotPattern<spirv::NotOp>,
1838 BitcastConversionPattern,
1839 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1840 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1841 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1842 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1843 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1844 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1845 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1848 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1849 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1850 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1851 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1852 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1853 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1854 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1855 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1856 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1857 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1858 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1859 LLVM::FCmpPredicate::uge>,
1860 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1861 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1862 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1863 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1864 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1865 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1866 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1867 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1868 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1869 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1870 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1873 ConstantScalarAndVectorPattern,
1876 BranchConversionPattern, BranchConditionalConversionPattern,
1877 FunctionCallPattern, LoopPattern, SelectionPattern,
1878 ErasePattern<spirv::MergeOp>,
1881 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1884 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1885 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1886 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1887 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1888 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1889 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1890 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1891 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1892 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1893 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1894 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1895 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1896 InverseSqrtPattern, TanPattern, TanhPattern,
1899 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1900 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1901 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1902 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1903 NotPattern<spirv::LogicalNotOp>,
1906 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1907 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1910 CompositeExtractPattern, CompositeInsertPattern,
1911 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1912 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1913 VectorShufflePattern,
1916 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1917 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1918 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1921 ReturnPattern, ReturnValuePattern,
1924 ControlBarrierPattern<spirv::ControlBarrierOp>,
1925 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1926 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1929 GroupReducePattern<spirv::GroupIAddOp>,
1930 GroupReducePattern<spirv::GroupFAddOp>,
1931 GroupReducePattern<spirv::GroupFMinOp>,
1932 GroupReducePattern<spirv::GroupUMinOp>,
1933 GroupReducePattern<spirv::GroupSMinOp,
true>,
1934 GroupReducePattern<spirv::GroupFMaxOp>,
1935 GroupReducePattern<spirv::GroupUMaxOp>,
1936 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1937 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1939 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1941 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1943 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1945 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1947 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1949 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1951 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1953 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1955 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1957 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1959 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1961 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1963 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1965 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1967 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1971 patterns.
add<GlobalVariablePattern>(clientAPI, patterns.
getContext(),
1977 patterns.
add<FuncConversionPattern>(patterns.
getContext(), typeConverter);
1982 patterns.
add<ModuleConversionPattern>(patterns.
getContext(), typeConverter);
1993 auto spvModules = module.getOps<spirv::ModuleOp>();
1994 for (
auto spvModule : spvModules) {
1995 spvModule.walk([&](spirv::GlobalVariableOp op) {
1996 IntegerAttr descriptorSet =
1998 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
2001 if (descriptorSet && binding) {
2004 auto moduleAndName =
2005 spvModule.getName().has_value()
2006 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
2007 : op.getSymName().str();
2009 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
2010 std::to_string(descriptorSet.getInt()),
2011 std::to_string(binding.getInt()));
2017 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.