23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
38 if (
auto vecType = dyn_cast<VectorType>(type))
39 return vecType.getElementType().isSignedInteger();
47 if (
auto vecType = dyn_cast<VectorType>(type))
48 return vecType.getElementType().isUnsignedInteger();
55 if (
auto intType = dyn_cast<IntegerType>(type))
56 return intType.getWidth();
57 if (
auto vecType = dyn_cast<VectorType>(type))
58 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59 return intType.getWidth();
66 "bitwidth is not supported for this type");
69 auto vecType = dyn_cast<VectorType>(type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
78 if (
auto vecTy = dyn_cast<VectorType>(type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(type).getWidth();
85 if (
auto vecType = dyn_cast<VectorType>(type)) {
86 auto integerType = cast<IntegerType>(vecType.getElementType());
89 auto integerType = cast<IntegerType>(type);
96 if (isa<VectorType>(srcType)) {
97 return LLVM::ConstantOp::create(
98 rewriter, loc, dstType,
102 return LLVM::ConstantOp::create(rewriter, loc, dstType,
109 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
110 auto floatType = cast<FloatType>(vecType.getElementType());
111 return LLVM::ConstantOp::create(
112 rewriter, loc, dstType,
116 auto floatType = cast<FloatType>(srcType);
117 return LLVM::ConstantOp::create(rewriter, loc, dstType,
130 auto srcType = value.
getType();
136 if (valueBitWidth < targetBitWidth)
137 return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
142 if (valueBitWidth > targetBitWidth)
143 return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
152 auto llvmVectorType = typeConverter.
convertType(vectorType);
154 Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
155 for (
unsigned i = 0; i < numElements; ++i) {
156 auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
158 broadcasted = LLVM::InsertElementOp::create(
159 rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index);
168 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, value, numElements, typeConverter, rewriter);
203 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
213 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
220 return LLVM::ConstantOp::create(
229 unsigned alignment,
bool isVolatile,
230 bool isNonTemporal) {
231 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232 auto dstType = typeConverter.
convertType(loadOp.getType());
236 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237 isVolatile, isNonTemporal);
240 auto storeOp = cast<spirv::StoreOp>(op);
241 spirv::StoreOpAdaptor adaptor(operands);
243 adaptor.getPtr(), alignment,
244 isVolatile, isNonTemporal);
259 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
263 auto llvmElementType = converter.
convertType(elementType);
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
295 if (!memberDecorations.empty())
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
316 getTypeConverter()->convertType(op.getComponentPtr().getType());
320 auto indices = llvm::to_vector<4>(adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(indexType);
326 LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
328 indices.insert(indices.begin(), zero);
330 auto elementType = getTypeConverter()->convertType(
331 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
335 adaptor.getBasePtr(), indices);
345 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
347 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
356 class BitFieldInsertPattern
362 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
364 auto srcType = op.getType();
365 auto dstType = getTypeConverter()->convertType(srcType);
372 *getTypeConverter(), rewriter);
374 *getTypeConverter(), rewriter);
378 Value maskShiftedByCount =
379 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
380 Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
381 maskShiftedByCount, minusOne);
382 Value maskShiftedByCountAndOffset =
383 LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
384 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
385 maskShiftedByCountAndOffset, minusOne);
390 LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
391 Value insertShiftedByOffset =
392 LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
394 insertShiftedByOffset);
400 class ConstantScalarAndVectorPattern
406 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
408 auto srcType = constOp.getType();
409 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
412 auto dstType = getTypeConverter()->convertType(srcType);
425 if (isa<VectorType>(srcType)) {
426 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
429 dstElementsAttr.mapValues(
430 signlessType, [&](
const APInt &value) {
return value; }));
433 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
434 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
439 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
444 class BitFieldSExtractPattern
450 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
452 auto srcType = op.getType();
453 auto dstType = getTypeConverter()->convertType(srcType);
460 *getTypeConverter(), rewriter);
462 *getTypeConverter(), rewriter);
465 IntegerType integerType;
466 if (
auto vecType = dyn_cast<VectorType>(srcType))
467 integerType = cast<IntegerType>(vecType.getElementType());
469 integerType = cast<IntegerType>(srcType);
473 isa<VectorType>(srcType)
474 ? LLVM::ConstantOp::create(
475 rewriter, loc, dstType,
477 : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
481 Value countPlusOffset =
482 LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
483 Value amountToShiftLeft =
484 LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
485 Value baseShiftedLeft = LLVM::ShlOp::create(
486 rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
489 Value amountToShiftRight =
490 LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
497 class BitFieldUExtractPattern
503 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
505 auto srcType = op.getType();
506 auto dstType = getTypeConverter()->convertType(srcType);
513 *getTypeConverter(), rewriter);
515 *getTypeConverter(), rewriter);
519 Value maskShiftedByCount =
520 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
521 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
526 LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
537 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
540 branchOp.getTarget());
545 class BranchConditionalConversionPattern
552 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
556 if (
auto weights = op.getBranchWeights()) {
558 for (
auto weight : weights->getAsRange<IntegerAttr>())
559 weightValues.push_back(weight.getInt());
564 op, op.getCondition(), op.getTrueBlockArguments(),
565 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
574 class CompositeExtractPattern
580 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
582 auto dstType = this->getTypeConverter()->convertType(op.getType());
586 Type containerType = op.getComposite().getType();
587 if (isa<VectorType>(containerType)) {
589 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
592 op, dstType, adaptor.getComposite(), index);
597 op, adaptor.getComposite(),
606 class CompositeInsertPattern
612 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
614 auto dstType = this->getTypeConverter()->convertType(op.getType());
618 Type containerType = op.getComposite().getType();
619 if (isa<VectorType>(containerType)) {
621 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
624 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
629 op, adaptor.getComposite(), adaptor.getObject(),
637 template <
typename SPIRVOp,
typename LLVMOp>
643 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
645 auto dstType = this->getTypeConverter()->convertType(op.getType());
648 rewriter.template replaceOpWithNewOp<LLVMOp>(
649 op, dstType, adaptor.getOperands(), op->getAttrs());
656 class ExecutionModePattern
662 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
667 ModuleOp module = op->getParentOfType<ModuleOp>();
668 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
669 std::string moduleName;
670 if (module.getName().has_value())
671 moduleName =
"_" + module.getName()->str();
674 std::string executionModeInfoName = llvm::formatv(
675 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
676 static_cast<uint32_t
>(executionModeAttr.getValue()));
689 fields.push_back(llvmI32Type);
690 ArrayAttr values = op.getValues();
691 if (!values.empty()) {
693 fields.push_back(arrayType);
695 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698 auto global = LLVM::GlobalOp::create(
700 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
703 Region ®ion = global.getInitializerRegion();
708 Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
709 Value executionMode = LLVM::ConstantOp::create(
710 rewriter, loc, llvmI32Type,
712 static_cast<uint32_t
>(executionModeAttr.getValue())));
714 structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
715 executionMode, position);
718 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
719 auto attr = values.getValue()[i];
720 Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
721 structValue = LLVM::InsertValueOp::create(
734 class GlobalVariablePattern
737 template <
typename... Args>
738 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
740 std::forward<Args>(args)...),
741 clientAPI(clientAPI) {}
744 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
748 if (op.getInitializer())
751 auto srcType = cast<spirv::PointerType>(op.getType());
752 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
759 auto storageClass = srcType.getStorageClass();
760 switch (storageClass) {
761 case spirv::StorageClass::Input:
762 case spirv::StorageClass::Private:
763 case spirv::StorageClass::Output:
764 case spirv::StorageClass::StorageBuffer:
765 case spirv::StorageClass::UniformConstant:
774 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
775 (storageClass == spirv::StorageClass::UniformConstant);
781 auto linkage = storageClass == spirv::StorageClass::Private
782 ? LLVM::Linkage::Private
783 : LLVM::Linkage::External;
784 StringAttr locationAttrName = op.getLocationAttrName();
785 IntegerAttr locationAttr = op.getLocationAttr();
787 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
792 newGlobalOp->setAttr(locationAttrName, locationAttr);
798 spirv::ClientAPI clientAPI;
803 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
809 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
812 Type fromType = op.getOperand().getType();
813 Type toType = op.getType();
815 auto dstType = this->getTypeConverter()->convertType(toType);
820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821 adaptor.getOperands());
825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826 adaptor.getOperands());
833 class FunctionCallPattern
839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
841 if (callOp.getNumResults() == 0) {
843 callOp,
TypeRange(), adaptor.getOperands(), callOp->getAttrs());
844 newOp.getProperties().operandSegmentSizes = {
845 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
851 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856 newOp.getProperties().operandSegmentSizes = {
857 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
864 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
870 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
873 auto dstType = this->getTypeConverter()->convertType(op.getType());
877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878 op, dstType, predicate, op.getOperand1(), op.getOperand2());
884 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
890 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
893 auto dstType = this->getTypeConverter()->convertType(op.getType());
897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898 op, dstType, predicate, op.getOperand1(), op.getOperand2());
903 class InverseSqrtPattern
909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
911 auto srcType = op.getType();
912 auto dstType = getTypeConverter()->convertType(srcType);
918 Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
925 template <
typename SPIRVOp>
931 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
933 if (!op.getMemoryAccess()) {
935 *this->getTypeConverter(), 0,
939 auto memoryAccess = *op.getMemoryAccess();
940 switch (memoryAccess) {
941 case spirv::MemoryAccess::Aligned:
943 case spirv::MemoryAccess::Nontemporal:
944 case spirv::MemoryAccess::Volatile: {
946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
950 *this->getTypeConverter(), alignment,
951 isVolatile, isNonTemporal);
961 template <
typename SPIRVOp>
967 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
969 auto srcType = notOp.getType();
970 auto dstType = this->getTypeConverter()->convertType(srcType);
977 isa<VectorType>(srcType)
978 ? LLVM::ConstantOp::create(
979 rewriter, loc, dstType,
981 : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983 notOp.getOperand(), mask);
989 template <
typename SPIRVOp>
995 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1023 adaptor.getOperands());
1032 bool convergent =
true) {
1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1039 func = LLVM::LLVMFuncOp::create(
1040 b, symbolTable->
getLoc(), name,
1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043 func.setConvergent(convergent);
1044 func.setNoUnwind(
true);
1045 func.setWillReturn(
true);
1050 LLVM::LLVMFuncOp func,
1052 auto call = LLVM::CallOp::create(builder, loc, func, args);
1053 call.setCConv(func.getCConv());
1054 call.setConvergentAttr(func.getConvergentAttr());
1055 call.setNoUnwindAttr(func.getNoUnwindAttr());
1056 call.setWillReturnAttr(func.getWillReturnAttr());
1060 template <
typename BarrierOpTy>
1067 static constexpr StringRef getFuncName();
1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1072 constexpr StringRef funcName = getFuncName();
1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1078 Type voidTy = rewriter.
getType<LLVM::LLVMVoidType>();
1079 LLVM::LLVMFuncOp func =
1082 Location loc = controlBarrierOp->getLoc();
1083 Value execution = LLVM::ConstantOp::create(
1084 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1085 Value memory = LLVM::ConstantOp::create(
1086 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1087 Value semantics = LLVM::ConstantOp::create(
1088 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1091 {execution, memory, semantics});
1093 rewriter.
replaceOp(controlBarrierOp, call);
1100 StringRef getTypeMangling(
Type type,
bool isSigned) {
1102 .Case<Float16Type>([](
auto) {
return "Dh"; })
1103 .Case<Float32Type>([](
auto) {
return "f"; })
1104 .Case<Float64Type>([](
auto) {
return "d"; })
1105 .Case<IntegerType>([isSigned](IntegerType intTy) {
1106 switch (intTy.getWidth()) {
1110 return (isSigned) ?
"a" :
"c";
1112 return (isSigned) ?
"s" :
"t";
1114 return (isSigned) ?
"i" :
"j";
1116 return (isSigned) ?
"l" :
"m";
1118 llvm_unreachable(
"Unsupported integer width");
1121 .DefaultUnreachable(
"No mangling defined");
1124 template <
typename ReduceOp>
1125 constexpr StringLiteral getGroupFuncName();
1128 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1129 return "_Z17__spirv_GroupIAddii";
1132 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1133 return "_Z17__spirv_GroupFAddii";
1136 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1137 return "_Z17__spirv_GroupSMinii";
1140 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1141 return "_Z17__spirv_GroupUMinii";
1144 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1145 return "_Z17__spirv_GroupFMinii";
1148 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1149 return "_Z17__spirv_GroupSMaxii";
1152 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1153 return "_Z17__spirv_GroupUMaxii";
1156 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1157 return "_Z17__spirv_GroupFMaxii";
1160 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1161 return "_Z27__spirv_GroupNonUniformIAddii";
1164 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1165 return "_Z27__spirv_GroupNonUniformFAddii";
1168 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1169 return "_Z27__spirv_GroupNonUniformIMulii";
1172 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1173 return "_Z27__spirv_GroupNonUniformFMulii";
1176 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1177 return "_Z27__spirv_GroupNonUniformSMinii";
1180 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1181 return "_Z27__spirv_GroupNonUniformUMinii";
1184 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1185 return "_Z27__spirv_GroupNonUniformFMinii";
1188 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1189 return "_Z27__spirv_GroupNonUniformSMaxii";
1192 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1193 return "_Z27__spirv_GroupNonUniformUMaxii";
1196 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1197 return "_Z27__spirv_GroupNonUniformFMaxii";
1200 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1201 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1204 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1205 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1208 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1209 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1212 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1213 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1216 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1217 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1220 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1221 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1225 template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1231 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1234 Type retTy = op.getResult().getType();
1239 funcName += getTypeMangling(retTy,
false);
1243 if constexpr (NonUniform) {
1244 if (adaptor.getClusterSize()) {
1246 paramTypes.push_back(i32Ty);
1251 op->template getParentWithTrait<OpTrait::SymbolTable>();
1253 LLVM::LLVMFuncOp func =
1257 Value scope = LLVM::ConstantOp::create(
1258 rewriter, loc, i32Ty,
1259 static_cast<int32_t
>(adaptor.getExecutionScope()));
1260 Value groupOp = LLVM::ConstantOp::create(
1261 rewriter, loc, i32Ty,
1262 static_cast<int32_t
>(adaptor.getGroupOperation()));
1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1274 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1275 return "_Z22__spirv_ControlBarrieriii";
1280 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1281 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1286 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1287 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1350 if (loopOp.getBody().empty()) {
1365 Block *entryBlock = loopOp.getEntryBlock();
1367 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1370 Block *headerBlock = loopOp.getHeaderBlock();
1372 LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
1376 Block *mergeBlock = loopOp.getMergeBlock();
1380 LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1408 if (op.getBody().getBlocks().size() <= 2) {
1420 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1426 auto *headerBlock = op.getHeaderBlock();
1428 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1434 auto *mergeBlock = op.getMergeBlock();
1438 LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
1441 Block *trueBlock = condBrOp.getTrueBlock();
1442 Block *falseBlock = condBrOp.getFalseBlock();
1444 LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
1445 condBrOp.getTrueTargetOperands(), falseBlock,
1446 condBrOp.getFalseTargetOperands());
1450 rewriter.
replaceOp(op, continueBlock->getArguments());
1459 template <
typename SPIRVOp,
typename LLVMOp>
1465 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1468 auto dstType = this->getTypeConverter()->convertType(op.getType());
1472 Type op1Type = op.getOperand1().getType();
1473 Type op2Type = op.getOperand2().getType();
1475 if (op1Type == op2Type) {
1476 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477 adaptor.getOperands());
1481 std::optional<uint64_t> dstTypeWidth =
1483 std::optional<uint64_t> op2TypeWidth =
1486 if (!dstTypeWidth || !op2TypeWidth)
1491 if (op2TypeWidth < dstTypeWidth) {
1494 LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1497 LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1499 }
else if (op2TypeWidth == dstTypeWidth) {
1500 extended = adaptor.getOperand2();
1506 LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
1517 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1519 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1524 Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
1525 Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
1542 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1544 auto srcType = tanhOp.getType();
1545 auto dstType = getTypeConverter()->convertType(srcType);
1552 LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
1553 Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
1556 LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
1558 LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
1570 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1572 auto srcType = varOp.getType();
1574 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1575 auto init = varOp.getInitializer();
1576 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1579 auto dstType = getTypeConverter()->convertType(srcType);
1586 auto elementType = getTypeConverter()->convertType(pointerTo);
1593 auto elementType = getTypeConverter()->convertType(pointerTo);
1597 LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
1598 LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
1608 class BitcastConversionPattern
1614 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1616 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1621 if (isa<LLVM::LLVMPointerType>(dstType)) {
1622 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1627 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1641 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1646 auto funcType = funcOp.getFunctionType();
1648 funcType.getNumInputs());
1650 ->convertFunctionSignature(
1652 false, signatureConverter);
1658 StringRef name = funcOp.getName();
1659 auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
1663 switch (funcOp.getFunctionControl()) {
1664 case spirv::FunctionControl::Inline:
1665 newFuncOp.setAlwaysInline(
true);
1667 case spirv::FunctionControl::DontInline:
1668 newFuncOp.setNoInline(
true);
1671 #define DISPATCH(functionControl, llvmAttr) \
1672 case functionControl: \
1673 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1676 DISPATCH(spirv::FunctionControl::Pure,
1678 DISPATCH(spirv::FunctionControl::Const,
1692 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1709 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1713 ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
1717 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1718 rewriter.
eraseOp(spvModuleOp);
1727 class VectorShufflePattern
1732 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1735 auto components = adaptor.getComponents();
1736 auto vector1 = adaptor.getVector1();
1737 auto vector2 = adaptor.getVector2();
1738 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1739 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1740 if (vector1Size == vector2Size) {
1742 op, vector1, vector2,
1743 LLVM::convertArrayToIndices<int32_t>(components));
1747 auto dstType = getTypeConverter()->convertType(op.getType());
1750 auto scalarType = cast<VectorType>(dstType).getElementType();
1751 auto componentsArray = components.getValue();
1754 Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
1755 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1756 if (!isa<IntegerAttr>(componentsArray[i]))
1757 return op.emitError(
"unable to support non-constant component");
1759 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1764 Value baseVector = vector1;
1765 if (indexVal >= vector1Size) {
1766 offsetVal = vector1Size;
1767 baseVector = vector2;
1770 Value dstIndex = LLVM::ConstantOp::create(
1771 rewriter, loc, llvmI32Type,
1773 Value index = LLVM::ConstantOp::create(
1774 rewriter, loc, llvmI32Type,
1777 auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
1779 targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
1780 extractOp, dstIndex);
1793 spirv::ClientAPI clientAPI) {
1810 spirv::ClientAPI clientAPI) {
1813 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1814 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1815 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1816 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1817 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1818 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1819 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1820 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1821 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1822 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1823 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1824 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1825 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1828 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1829 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1830 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1831 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1832 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1833 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1834 NotPattern<spirv::NotOp>,
1837 BitcastConversionPattern,
1838 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1839 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1840 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1841 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1842 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1843 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1844 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1847 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1848 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1849 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1850 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1851 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1852 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1853 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1854 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1855 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1856 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1857 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1858 LLVM::FCmpPredicate::uge>,
1859 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1860 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1861 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1862 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1863 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1864 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1865 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1866 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1867 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1868 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1869 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1872 ConstantScalarAndVectorPattern,
1875 BranchConversionPattern, BranchConditionalConversionPattern,
1876 FunctionCallPattern, LoopPattern, SelectionPattern,
1877 ErasePattern<spirv::MergeOp>,
1880 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1883 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1884 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1885 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1886 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1887 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1888 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1889 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1890 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1891 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1892 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1893 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1894 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1895 InverseSqrtPattern, TanPattern, TanhPattern,
1898 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1899 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1900 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1901 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1902 NotPattern<spirv::LogicalNotOp>,
1905 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1906 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1909 CompositeExtractPattern, CompositeInsertPattern,
1910 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1911 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1912 VectorShufflePattern,
1915 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1916 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1917 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1920 ReturnPattern, ReturnValuePattern,
1923 ControlBarrierPattern<spirv::ControlBarrierOp>,
1924 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1925 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1928 GroupReducePattern<spirv::GroupIAddOp>,
1929 GroupReducePattern<spirv::GroupFAddOp>,
1930 GroupReducePattern<spirv::GroupFMinOp>,
1931 GroupReducePattern<spirv::GroupUMinOp>,
1932 GroupReducePattern<spirv::GroupSMinOp,
true>,
1933 GroupReducePattern<spirv::GroupFMaxOp>,
1934 GroupReducePattern<spirv::GroupUMaxOp>,
1935 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1936 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1938 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1940 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1942 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1944 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1946 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1948 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1950 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1952 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1954 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1956 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1958 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1960 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1962 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1964 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1966 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1981 patterns.add<ModuleConversionPattern>(
patterns.getContext(), typeConverter);
1992 auto spvModules = module.getOps<spirv::ModuleOp>();
1993 for (
auto spvModule : spvModules) {
1994 spvModule.walk([&](spirv::GlobalVariableOp op) {
1995 IntegerAttr descriptorSet =
1997 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
2000 if (descriptorSet && binding) {
2003 auto moduleAndName =
2004 spvModule.getName().has_value()
2005 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
2006 : op.getSymName().str();
2008 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
2009 std::to_string(descriptorSet.getInt()),
2010 std::to_string(binding.getInt()));
2016 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getBlock() const
Returns the current block of the builder.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.