23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
38 if (
auto vecType = dyn_cast<VectorType>(type))
39 return vecType.getElementType().isSignedInteger();
47 if (
auto vecType = dyn_cast<VectorType>(type))
48 return vecType.getElementType().isUnsignedInteger();
55 if (
auto intType = dyn_cast<IntegerType>(type))
56 return intType.getWidth();
57 if (
auto vecType = dyn_cast<VectorType>(type))
58 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59 return intType.getWidth();
66 "bitwidth is not supported for this type");
69 auto vecType = dyn_cast<VectorType>(type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
78 if (
auto vecTy = dyn_cast<VectorType>(type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(type).getWidth();
85 if (
auto vecType = dyn_cast<VectorType>(type)) {
86 auto integerType = cast<IntegerType>(vecType.getElementType());
89 auto integerType = cast<IntegerType>(type);
96 if (isa<VectorType>(srcType)) {
97 return rewriter.
create<LLVM::ConstantOp>(
102 return rewriter.
create<LLVM::ConstantOp>(
109 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
110 auto floatType = cast<FloatType>(vecType.getElementType());
111 return rewriter.
create<LLVM::ConstantOp>(
116 auto floatType = cast<FloatType>(srcType);
117 return rewriter.
create<LLVM::ConstantOp>(
130 auto srcType = value.
getType();
136 if (valueBitWidth < targetBitWidth)
137 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
142 if (valueBitWidth > targetBitWidth)
143 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
152 auto llvmVectorType = typeConverter.
convertType(vectorType);
154 Value broadcasted = rewriter.
create<LLVM::PoisonOp>(loc, llvmVectorType);
155 for (
unsigned i = 0; i < numElements; ++i) {
156 auto index = rewriter.
create<LLVM::ConstantOp>(
158 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
159 loc, llvmVectorType, broadcasted, toBroadcast, index);
168 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, value, numElements, typeConverter, rewriter);
203 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
213 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
220 return rewriter.
create<LLVM::ConstantOp>(
229 unsigned alignment,
bool isVolatile,
230 bool isNonTemporal) {
231 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232 auto dstType = typeConverter.
convertType(loadOp.getType());
236 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237 isVolatile, isNonTemporal);
240 auto storeOp = cast<spirv::StoreOp>(op);
241 spirv::StoreOpAdaptor adaptor(operands);
243 adaptor.getPtr(), alignment,
244 isVolatile, isNonTemporal);
259 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
263 auto llvmElementType = converter.
convertType(elementType);
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
295 if (!memberDecorations.empty())
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
316 getTypeConverter()->convertType(op.getComponentPtr().getType());
320 auto indices = llvm::to_vector<4>(adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(indexType);
326 op.getLoc(), llvmIndexType, rewriter.
getIntegerAttr(indexType, 0));
327 indices.insert(indices.begin(), zero);
329 auto elementType = getTypeConverter()->convertType(
330 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
334 adaptor.getBasePtr(), indices);
344 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
346 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
355 class BitFieldInsertPattern
361 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
363 auto srcType = op.getType();
364 auto dstType = getTypeConverter()->convertType(srcType);
371 *getTypeConverter(), rewriter);
373 *getTypeConverter(), rewriter);
377 Value maskShiftedByCount =
378 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
379 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
380 maskShiftedByCount, minusOne);
381 Value maskShiftedByCountAndOffset =
382 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
384 loc, dstType, maskShiftedByCountAndOffset, minusOne);
389 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
390 Value insertShiftedByOffset =
391 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
393 insertShiftedByOffset);
399 class ConstantScalarAndVectorPattern
405 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
407 auto srcType = constOp.getType();
408 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
411 auto dstType = getTypeConverter()->convertType(srcType);
424 if (isa<VectorType>(srcType)) {
425 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
428 dstElementsAttr.mapValues(
429 signlessType, [&](
const APInt &value) {
return value; }));
432 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
433 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
438 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
443 class BitFieldSExtractPattern
449 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
451 auto srcType = op.getType();
452 auto dstType = getTypeConverter()->convertType(srcType);
459 *getTypeConverter(), rewriter);
461 *getTypeConverter(), rewriter);
464 IntegerType integerType;
465 if (
auto vecType = dyn_cast<VectorType>(srcType))
466 integerType = cast<IntegerType>(vecType.getElementType());
468 integerType = cast<IntegerType>(srcType);
472 isa<VectorType>(srcType)
473 ? rewriter.
create<LLVM::ConstantOp>(
476 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
480 Value countPlusOffset =
481 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
482 Value amountToShiftLeft =
483 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
484 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
485 loc, dstType, op.getBase(), amountToShiftLeft);
488 Value amountToShiftRight =
489 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
496 class BitFieldUExtractPattern
502 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
504 auto srcType = op.getType();
505 auto dstType = getTypeConverter()->convertType(srcType);
512 *getTypeConverter(), rewriter);
514 *getTypeConverter(), rewriter);
518 Value maskShiftedByCount =
519 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
520 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
525 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
536 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
539 branchOp.getTarget());
544 class BranchConditionalConversionPattern
551 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
555 if (
auto weights = op.getBranchWeights()) {
557 for (
auto weight : weights->getAsRange<IntegerAttr>())
558 weightValues.push_back(weight.getInt());
563 op, op.getCondition(), op.getTrueBlockArguments(),
564 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
573 class CompositeExtractPattern
579 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
581 auto dstType = this->getTypeConverter()->convertType(op.getType());
585 Type containerType = op.getComposite().getType();
586 if (isa<VectorType>(containerType)) {
588 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
591 op, dstType, adaptor.getComposite(), index);
596 op, adaptor.getComposite(),
605 class CompositeInsertPattern
611 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
613 auto dstType = this->getTypeConverter()->convertType(op.getType());
617 Type containerType = op.getComposite().getType();
618 if (isa<VectorType>(containerType)) {
620 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
623 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628 op, adaptor.getComposite(), adaptor.getObject(),
636 template <
typename SPIRVOp,
typename LLVMOp>
642 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
644 auto dstType = this->getTypeConverter()->convertType(op.getType());
647 rewriter.template replaceOpWithNewOp<LLVMOp>(
648 op, dstType, adaptor.getOperands(), op->getAttrs());
655 class ExecutionModePattern
661 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
666 ModuleOp module = op->getParentOfType<ModuleOp>();
667 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
668 std::string moduleName;
669 if (module.getName().has_value())
670 moduleName =
"_" + module.getName()->str();
673 std::string executionModeInfoName = llvm::formatv(
674 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
675 static_cast<uint32_t
>(executionModeAttr.getValue()));
688 fields.push_back(llvmI32Type);
689 ArrayAttr values = op.getValues();
690 if (!values.empty()) {
692 fields.push_back(arrayType);
694 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
697 auto global = rewriter.
create<LLVM::GlobalOp>(
699 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
702 Region ®ion = global.getInitializerRegion();
707 Value structValue = rewriter.
create<LLVM::PoisonOp>(loc, structType);
708 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
711 static_cast<uint32_t
>(executionModeAttr.getValue())));
712 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
716 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
717 auto attr = values.getValue()[i];
718 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
719 structValue = rewriter.
create<LLVM::InsertValueOp>(
732 class GlobalVariablePattern
735 template <
typename... Args>
736 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
738 std::forward<Args>(args)...),
739 clientAPI(clientAPI) {}
742 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
746 if (op.getInitializer())
749 auto srcType = cast<spirv::PointerType>(op.getType());
750 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
757 auto storageClass = srcType.getStorageClass();
758 switch (storageClass) {
759 case spirv::StorageClass::Input:
760 case spirv::StorageClass::Private:
761 case spirv::StorageClass::Output:
762 case spirv::StorageClass::StorageBuffer:
763 case spirv::StorageClass::UniformConstant:
772 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
773 (storageClass == spirv::StorageClass::UniformConstant);
779 auto linkage = storageClass == spirv::StorageClass::Private
780 ? LLVM::Linkage::Private
781 : LLVM::Linkage::External;
783 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
787 if (op.getLocationAttr())
788 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
794 spirv::ClientAPI clientAPI;
799 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
805 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
808 Type fromType = op.getOperand().getType();
809 Type toType = op.getType();
811 auto dstType = this->getTypeConverter()->convertType(toType);
816 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
817 adaptor.getOperands());
821 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
822 adaptor.getOperands());
829 class FunctionCallPattern
835 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
837 if (callOp.getNumResults() == 0) {
839 callOp,
TypeRange(), adaptor.getOperands(), callOp->getAttrs());
840 newOp.getProperties().operandSegmentSizes = {
841 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
847 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
851 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
852 newOp.getProperties().operandSegmentSizes = {
853 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
860 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
866 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
869 auto dstType = this->getTypeConverter()->convertType(op.getType());
873 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
874 op, dstType, predicate, op.getOperand1(), op.getOperand2());
880 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
886 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
889 auto dstType = this->getTypeConverter()->convertType(op.getType());
893 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
894 op, dstType, predicate, op.getOperand1(), op.getOperand2());
899 class InverseSqrtPattern
905 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
907 auto srcType = op.getType();
908 auto dstType = getTypeConverter()->convertType(srcType);
914 Value sqrt = rewriter.
create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
921 template <
typename SPIRVOp>
927 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
929 if (!op.getMemoryAccess()) {
931 *this->getTypeConverter(), 0,
935 auto memoryAccess = *op.getMemoryAccess();
936 switch (memoryAccess) {
937 case spirv::MemoryAccess::Aligned:
939 case spirv::MemoryAccess::Nontemporal:
940 case spirv::MemoryAccess::Volatile: {
942 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
943 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
944 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
946 *this->getTypeConverter(), alignment,
947 isVolatile, isNonTemporal);
957 template <
typename SPIRVOp>
963 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
965 auto srcType = notOp.getType();
966 auto dstType = this->getTypeConverter()->convertType(srcType);
973 isa<VectorType>(srcType)
974 ? rewriter.
create<LLVM::ConstantOp>(
977 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
978 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
979 notOp.getOperand(), mask);
985 template <
typename SPIRVOp>
991 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1003 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1016 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1019 adaptor.getOperands());
1028 bool convergent =
true) {
1029 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1035 func = b.create<LLVM::LLVMFuncOp>(
1036 symbolTable->
getLoc(), name,
1038 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1039 func.setConvergent(convergent);
1040 func.setNoUnwind(
true);
1041 func.setWillReturn(
true);
1046 LLVM::LLVMFuncOp func,
1048 auto call = builder.
create<LLVM::CallOp>(loc, func, args);
1049 call.setCConv(func.getCConv());
1050 call.setConvergentAttr(func.getConvergentAttr());
1051 call.setNoUnwindAttr(func.getNoUnwindAttr());
1052 call.setWillReturnAttr(func.getWillReturnAttr());
1056 template <
typename BarrierOpTy>
1063 static constexpr StringRef getFuncName();
1066 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1068 constexpr StringRef funcName = getFuncName();
1070 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1074 Type voidTy = rewriter.
getType<LLVM::LLVMVoidType>();
1075 LLVM::LLVMFuncOp func =
1078 Location loc = controlBarrierOp->getLoc();
1079 Value execution = rewriter.
create<LLVM::ConstantOp>(
1080 loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1082 loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1083 Value semantics = rewriter.
create<LLVM::ConstantOp>(
1084 loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1087 {execution, memory, semantics});
1089 rewriter.
replaceOp(controlBarrierOp, call);
1096 StringRef getTypeMangling(
Type type,
bool isSigned) {
1098 .Case<Float16Type>([](
auto) {
return "Dh"; })
1099 .Case<Float32Type>([](
auto) {
return "f"; })
1100 .Case<Float64Type>([](
auto) {
return "d"; })
1101 .Case<IntegerType>([isSigned](IntegerType intTy) {
1102 switch (intTy.getWidth()) {
1106 return (isSigned) ?
"a" :
"c";
1108 return (isSigned) ?
"s" :
"t";
1110 return (isSigned) ?
"i" :
"j";
1112 return (isSigned) ?
"l" :
"m";
1114 llvm_unreachable(
"Unsupported integer width");
1118 llvm_unreachable(
"No mangling defined");
1123 template <
typename ReduceOp>
1124 constexpr StringLiteral getGroupFuncName();
1127 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1128 return "_Z17__spirv_GroupIAddii";
1131 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1132 return "_Z17__spirv_GroupFAddii";
1135 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1136 return "_Z17__spirv_GroupSMinii";
1139 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1140 return "_Z17__spirv_GroupUMinii";
1143 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1144 return "_Z17__spirv_GroupFMinii";
1147 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1148 return "_Z17__spirv_GroupSMaxii";
1151 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1152 return "_Z17__spirv_GroupUMaxii";
1155 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1156 return "_Z17__spirv_GroupFMaxii";
1159 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1160 return "_Z27__spirv_GroupNonUniformIAddii";
1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1164 return "_Z27__spirv_GroupNonUniformFAddii";
1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1168 return "_Z27__spirv_GroupNonUniformIMulii";
1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1172 return "_Z27__spirv_GroupNonUniformFMulii";
1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1176 return "_Z27__spirv_GroupNonUniformSMinii";
1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1180 return "_Z27__spirv_GroupNonUniformUMinii";
1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1184 return "_Z27__spirv_GroupNonUniformFMinii";
1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1188 return "_Z27__spirv_GroupNonUniformSMaxii";
1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1192 return "_Z27__spirv_GroupNonUniformUMaxii";
1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1196 return "_Z27__spirv_GroupNonUniformFMaxii";
1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1200 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1204 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1208 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1212 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1216 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1220 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1224 template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1230 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1233 Type retTy = op.getResult().getType();
1238 funcName += getTypeMangling(retTy,
false);
1242 if constexpr (NonUniform) {
1243 if (adaptor.getClusterSize()) {
1245 paramTypes.push_back(i32Ty);
1250 op->template getParentWithTrait<OpTrait::SymbolTable>();
1252 LLVM::LLVMFuncOp func =
1257 loc, i32Ty,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1258 Value groupOp = rewriter.
create<LLVM::ConstantOp>(
1259 loc, i32Ty,
static_cast<int32_t
>(adaptor.getGroupOperation()));
1261 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1271 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1272 return "_Z22__spirv_ControlBarrieriii";
1277 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1278 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1283 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1284 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1340 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1347 if (loopOp.getBody().empty()) {
1362 Block *entryBlock = loopOp.getEntryBlock();
1364 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1367 Block *headerBlock = loopOp.getHeaderBlock();
1369 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1373 Block *mergeBlock = loopOp.getMergeBlock();
1377 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1393 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1405 if (op.getBody().getBlocks().size() <= 2) {
1417 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1423 auto *headerBlock = op.getHeaderBlock();
1425 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1432 auto *mergeBlock = op.getMergeBlock();
1436 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1439 Block *trueBlock = condBrOp.getTrueBlock();
1440 Block *falseBlock = condBrOp.getFalseBlock();
1442 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1443 condBrOp.getTrueTargetOperands(),
1445 condBrOp.getFalseTargetOperands());
1448 rewriter.
replaceOp(op, continueBlock->getArguments());
1457 template <
typename SPIRVOp,
typename LLVMOp>
1463 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1466 auto dstType = this->getTypeConverter()->convertType(op.getType());
1470 Type op1Type = op.getOperand1().getType();
1471 Type op2Type = op.getOperand2().getType();
1473 if (op1Type == op2Type) {
1474 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1475 adaptor.getOperands());
1479 std::optional<uint64_t> dstTypeWidth =
1481 std::optional<uint64_t> op2TypeWidth =
1484 if (!dstTypeWidth || !op2TypeWidth)
1489 if (op2TypeWidth < dstTypeWidth) {
1491 extended = rewriter.template create<LLVM::ZExtOp>(
1492 loc, dstType, adaptor.getOperand2());
1494 extended = rewriter.template create<LLVM::SExtOp>(
1495 loc, dstType, adaptor.getOperand2());
1497 }
else if (op2TypeWidth == dstTypeWidth) {
1498 extended = adaptor.getOperand2();
1503 Value result = rewriter.template create<LLVMOp>(
1504 loc, dstType, adaptor.getOperand1(), extended);
1515 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1517 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1522 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1523 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1540 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1542 auto srcType = tanhOp.getType();
1543 auto dstType = getTypeConverter()->convertType(srcType);
1550 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1551 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1554 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1556 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1568 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1570 auto srcType = varOp.getType();
1572 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1573 auto init = varOp.getInitializer();
1574 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1577 auto dstType = getTypeConverter()->convertType(srcType);
1584 auto elementType = getTypeConverter()->convertType(pointerTo);
1591 auto elementType = getTypeConverter()->convertType(pointerTo);
1595 rewriter.
create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1596 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1606 class BitcastConversionPattern
1612 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1614 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1619 if (isa<LLVM::LLVMPointerType>(dstType)) {
1620 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1625 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1639 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1644 auto funcType = funcOp.getFunctionType();
1646 funcType.getNumInputs());
1648 ->convertFunctionSignature(
1650 false, signatureConverter);
1656 StringRef name = funcOp.getName();
1657 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1661 switch (funcOp.getFunctionControl()) {
1662 case spirv::FunctionControl::Inline:
1663 newFuncOp.setAlwaysInline(
true);
1665 case spirv::FunctionControl::DontInline:
1666 newFuncOp.setNoInline(
true);
1669 #define DISPATCH(functionControl, llvmAttr) \
1670 case functionControl: \
1671 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1674 DISPATCH(spirv::FunctionControl::Pure,
1676 DISPATCH(spirv::FunctionControl::Const,
1690 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1707 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1711 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1715 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1716 rewriter.
eraseOp(spvModuleOp);
1725 class VectorShufflePattern
1730 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1733 auto components = adaptor.getComponents();
1734 auto vector1 = adaptor.getVector1();
1735 auto vector2 = adaptor.getVector2();
1736 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1737 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1738 if (vector1Size == vector2Size) {
1740 op, vector1, vector2,
1741 LLVM::convertArrayToIndices<int32_t>(components));
1745 auto dstType = getTypeConverter()->convertType(op.getType());
1748 auto scalarType = cast<VectorType>(dstType).getElementType();
1749 auto componentsArray = components.getValue();
1752 Value targetOp = rewriter.
create<LLVM::PoisonOp>(loc, dstType);
1753 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1754 if (!isa<IntegerAttr>(componentsArray[i]))
1755 return op.
emitError(
"unable to support non-constant component");
1757 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1762 Value baseVector = vector1;
1763 if (indexVal >= vector1Size) {
1764 offsetVal = vector1Size;
1765 baseVector = vector2;
1768 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1774 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1775 loc, scalarType, baseVector, index);
1776 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1777 extractOp, dstIndex);
1790 spirv::ClientAPI clientAPI) {
1807 spirv::ClientAPI clientAPI) {
1810 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1811 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1812 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1813 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1814 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1815 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1816 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1817 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1818 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1819 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1820 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1821 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1822 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1825 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1826 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1827 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1828 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1829 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1830 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1831 NotPattern<spirv::NotOp>,
1834 BitcastConversionPattern,
1835 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1836 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1837 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1838 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1839 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1840 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1841 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1844 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1845 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1846 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1847 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1848 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1849 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1850 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1851 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1852 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1853 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1854 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1855 LLVM::FCmpPredicate::uge>,
1856 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1857 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1858 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1859 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1860 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1861 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1862 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1863 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1864 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1865 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1866 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1869 ConstantScalarAndVectorPattern,
1872 BranchConversionPattern, BranchConditionalConversionPattern,
1873 FunctionCallPattern, LoopPattern, SelectionPattern,
1874 ErasePattern<spirv::MergeOp>,
1877 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1880 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1881 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1882 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1883 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1884 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1885 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1886 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1887 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1888 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1889 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1890 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1891 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1892 InverseSqrtPattern, TanPattern, TanhPattern,
1895 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1896 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1897 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1898 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1899 NotPattern<spirv::LogicalNotOp>,
1902 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1903 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1906 CompositeExtractPattern, CompositeInsertPattern,
1907 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1908 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1909 VectorShufflePattern,
1912 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1913 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1914 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1917 ReturnPattern, ReturnValuePattern,
1920 ControlBarrierPattern<spirv::ControlBarrierOp>,
1921 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1922 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1925 GroupReducePattern<spirv::GroupIAddOp>,
1926 GroupReducePattern<spirv::GroupFAddOp>,
1927 GroupReducePattern<spirv::GroupFMinOp>,
1928 GroupReducePattern<spirv::GroupUMinOp>,
1929 GroupReducePattern<spirv::GroupSMinOp,
true>,
1930 GroupReducePattern<spirv::GroupFMaxOp>,
1931 GroupReducePattern<spirv::GroupUMaxOp>,
1932 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1933 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1935 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1937 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1939 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1941 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1943 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1945 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1947 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1949 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1951 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1953 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1955 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1957 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1959 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1961 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1963 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1978 patterns.add<ModuleConversionPattern>(
patterns.getContext(), typeConverter);
1989 auto spvModules = module.getOps<spirv::ModuleOp>();
1990 for (
auto spvModule : spvModules) {
1991 spvModule.walk([&](spirv::GlobalVariableOp op) {
1992 IntegerAttr descriptorSet =
1994 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
1997 if (descriptorSet && binding) {
2000 auto moduleAndName =
2001 spvModule.getName().has_value()
2002 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
2003 : op.getSymName().str();
2005 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
2006 std::to_string(descriptorSet.getInt()),
2007 std::to_string(binding.getInt()));
2013 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops: zero-extends the value if it is narrower than llvmType, or truncates it if it is wider.
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.