23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/FormatVariadic.h"
26 #define DEBUG_TYPE "spirv-to-llvm-pattern"
38 if (
auto vecType = dyn_cast<VectorType>(type))
39 return vecType.getElementType().isSignedInteger();
47 if (
auto vecType = dyn_cast<VectorType>(type))
48 return vecType.getElementType().isUnsignedInteger();
55 if (
auto intType = dyn_cast<IntegerType>(type))
56 return intType.getWidth();
57 if (
auto vecType = dyn_cast<VectorType>(type))
58 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59 return intType.getWidth();
66 "bitwidth is not supported for this type");
69 auto vecType = dyn_cast<VectorType>(type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
78 if (
auto vecTy = dyn_cast<VectorType>(type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(type).getWidth();
85 if (
auto vecType = dyn_cast<VectorType>(type)) {
86 auto integerType = cast<IntegerType>(vecType.getElementType());
89 auto integerType = cast<IntegerType>(type);
96 if (isa<VectorType>(srcType)) {
97 return rewriter.
create<LLVM::ConstantOp>(
102 return rewriter.
create<LLVM::ConstantOp>(
109 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
110 auto floatType = cast<FloatType>(vecType.getElementType());
111 return rewriter.
create<LLVM::ConstantOp>(
116 auto floatType = cast<FloatType>(srcType);
117 return rewriter.
create<LLVM::ConstantOp>(
130 auto srcType = value.
getType();
136 if (valueBitWidth < targetBitWidth)
137 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
142 if (valueBitWidth > targetBitWidth)
143 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
152 auto llvmVectorType = typeConverter.
convertType(vectorType);
154 Value broadcasted = rewriter.
create<LLVM::PoisonOp>(loc, llvmVectorType);
155 for (
unsigned i = 0; i < numElements; ++i) {
156 auto index = rewriter.
create<LLVM::ConstantOp>(
158 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
159 loc, llvmVectorType, broadcasted, toBroadcast, index);
168 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, value, numElements, typeConverter, rewriter);
203 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
213 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
220 return rewriter.
create<LLVM::ConstantOp>(
229 unsigned alignment,
bool isVolatile,
230 bool isNonTemporal) {
231 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232 auto dstType = typeConverter.
convertType(loadOp.getType());
236 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237 isVolatile, isNonTemporal);
240 auto storeOp = cast<spirv::StoreOp>(op);
241 spirv::StoreOpAdaptor adaptor(operands);
243 adaptor.getPtr(), alignment,
244 isVolatile, isNonTemporal);
259 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
263 auto llvmElementType = converter.
convertType(elementType);
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
295 if (!memberDecorations.empty())
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
316 getTypeConverter()->convertType(op.getComponentPtr().getType());
320 auto indices = llvm::to_vector<4>(adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(indexType);
326 op.getLoc(), llvmIndexType, rewriter.
getIntegerAttr(indexType, 0));
327 indices.insert(indices.begin(), zero);
329 auto elementType = getTypeConverter()->convertType(
330 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
334 adaptor.getBasePtr(), indices);
344 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
346 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
355 class BitFieldInsertPattern
361 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
363 auto srcType = op.getType();
364 auto dstType = getTypeConverter()->convertType(srcType);
371 *getTypeConverter(), rewriter);
373 *getTypeConverter(), rewriter);
377 Value maskShiftedByCount =
378 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
379 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
380 maskShiftedByCount, minusOne);
381 Value maskShiftedByCountAndOffset =
382 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
384 loc, dstType, maskShiftedByCountAndOffset, minusOne);
389 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
390 Value insertShiftedByOffset =
391 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
393 insertShiftedByOffset);
399 class ConstantScalarAndVectorPattern
405 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
407 auto srcType = constOp.getType();
408 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
411 auto dstType = getTypeConverter()->convertType(srcType);
424 if (isa<VectorType>(srcType)) {
425 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
428 dstElementsAttr.mapValues(
429 signlessType, [&](
const APInt &value) {
return value; }));
432 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
433 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
438 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
443 class BitFieldSExtractPattern
449 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
451 auto srcType = op.getType();
452 auto dstType = getTypeConverter()->convertType(srcType);
459 *getTypeConverter(), rewriter);
461 *getTypeConverter(), rewriter);
464 IntegerType integerType;
465 if (
auto vecType = dyn_cast<VectorType>(srcType))
466 integerType = cast<IntegerType>(vecType.getElementType());
468 integerType = cast<IntegerType>(srcType);
472 isa<VectorType>(srcType)
473 ? rewriter.
create<LLVM::ConstantOp>(
476 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
480 Value countPlusOffset =
481 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
482 Value amountToShiftLeft =
483 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
484 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
485 loc, dstType, op.getBase(), amountToShiftLeft);
488 Value amountToShiftRight =
489 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
496 class BitFieldUExtractPattern
502 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
504 auto srcType = op.getType();
505 auto dstType = getTypeConverter()->convertType(srcType);
512 *getTypeConverter(), rewriter);
514 *getTypeConverter(), rewriter);
518 Value maskShiftedByCount =
519 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
520 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
525 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
536 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
539 branchOp.getTarget());
544 class BranchConditionalConversionPattern
551 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
555 if (
auto weights = op.getBranchWeights()) {
557 for (
auto weight : weights->getAsRange<IntegerAttr>())
558 weightValues.push_back(weight.getInt());
563 op, op.getCondition(), op.getTrueBlockArguments(),
564 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
573 class CompositeExtractPattern
579 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
581 auto dstType = this->getTypeConverter()->convertType(op.getType());
585 Type containerType = op.getComposite().getType();
586 if (isa<VectorType>(containerType)) {
588 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
591 op, dstType, adaptor.getComposite(), index);
596 op, adaptor.getComposite(),
605 class CompositeInsertPattern
611 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
613 auto dstType = this->getTypeConverter()->convertType(op.getType());
617 Type containerType = op.getComposite().getType();
618 if (isa<VectorType>(containerType)) {
620 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
623 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628 op, adaptor.getComposite(), adaptor.getObject(),
636 template <
typename SPIRVOp,
typename LLVMOp>
642 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
644 auto dstType = this->getTypeConverter()->convertType(op.getType());
647 rewriter.template replaceOpWithNewOp<LLVMOp>(
648 op, dstType, adaptor.getOperands(), op->getAttrs());
655 class ExecutionModePattern
661 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
666 ModuleOp module = op->getParentOfType<ModuleOp>();
667 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
668 std::string moduleName;
669 if (module.getName().has_value())
670 moduleName =
"_" + module.getName()->str();
673 std::string executionModeInfoName = llvm::formatv(
674 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
675 static_cast<uint32_t
>(executionModeAttr.getValue()));
688 fields.push_back(llvmI32Type);
689 ArrayAttr values = op.getValues();
690 if (!values.empty()) {
692 fields.push_back(arrayType);
694 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
697 auto global = rewriter.
create<LLVM::GlobalOp>(
699 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
702 Region ®ion = global.getInitializerRegion();
707 Value structValue = rewriter.
create<LLVM::PoisonOp>(loc, structType);
708 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
711 static_cast<uint32_t
>(executionModeAttr.getValue())));
712 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
716 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
717 auto attr = values.getValue()[i];
718 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
719 structValue = rewriter.
create<LLVM::InsertValueOp>(
732 class GlobalVariablePattern
735 template <
typename... Args>
736 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
738 std::forward<Args>(args)...),
739 clientAPI(clientAPI) {}
742 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
746 if (op.getInitializer())
749 auto srcType = cast<spirv::PointerType>(op.getType());
750 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
757 auto storageClass = srcType.getStorageClass();
758 switch (storageClass) {
759 case spirv::StorageClass::Input:
760 case spirv::StorageClass::Private:
761 case spirv::StorageClass::Output:
762 case spirv::StorageClass::StorageBuffer:
763 case spirv::StorageClass::UniformConstant:
772 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
773 (storageClass == spirv::StorageClass::UniformConstant);
779 auto linkage = storageClass == spirv::StorageClass::Private
780 ? LLVM::Linkage::Private
781 : LLVM::Linkage::External;
782 StringAttr locationAttrName = op.getLocationAttrName();
783 IntegerAttr locationAttr = op.getLocationAttr();
785 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
790 newGlobalOp->setAttr(locationAttrName, locationAttr);
796 spirv::ClientAPI clientAPI;
801 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
807 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
810 Type fromType = op.getOperand().getType();
811 Type toType = op.getType();
813 auto dstType = this->getTypeConverter()->convertType(toType);
818 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
819 adaptor.getOperands());
823 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
824 adaptor.getOperands());
831 class FunctionCallPattern
837 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
839 if (callOp.getNumResults() == 0) {
841 callOp,
TypeRange(), adaptor.getOperands(), callOp->getAttrs());
842 newOp.getProperties().operandSegmentSizes = {
843 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
849 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
853 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
854 newOp.getProperties().operandSegmentSizes = {
855 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
862 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
868 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
871 auto dstType = this->getTypeConverter()->convertType(op.getType());
875 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
876 op, dstType, predicate, op.getOperand1(), op.getOperand2());
882 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
888 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
891 auto dstType = this->getTypeConverter()->convertType(op.getType());
895 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
896 op, dstType, predicate, op.getOperand1(), op.getOperand2());
901 class InverseSqrtPattern
907 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
909 auto srcType = op.getType();
910 auto dstType = getTypeConverter()->convertType(srcType);
916 Value sqrt = rewriter.
create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
923 template <
typename SPIRVOp>
929 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
931 if (!op.getMemoryAccess()) {
933 *this->getTypeConverter(), 0,
937 auto memoryAccess = *op.getMemoryAccess();
938 switch (memoryAccess) {
939 case spirv::MemoryAccess::Aligned:
941 case spirv::MemoryAccess::Nontemporal:
942 case spirv::MemoryAccess::Volatile: {
944 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
945 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
946 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
948 *this->getTypeConverter(), alignment,
949 isVolatile, isNonTemporal);
959 template <
typename SPIRVOp>
965 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
967 auto srcType = notOp.getType();
968 auto dstType = this->getTypeConverter()->convertType(srcType);
975 isa<VectorType>(srcType)
976 ? rewriter.
create<LLVM::ConstantOp>(
979 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
980 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
981 notOp.getOperand(), mask);
987 template <
typename SPIRVOp>
993 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1005 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1018 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021 adaptor.getOperands());
1030 bool convergent =
true) {
1031 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1037 func = b.create<LLVM::LLVMFuncOp>(
1038 symbolTable->
getLoc(), name,
1040 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1041 func.setConvergent(convergent);
1042 func.setNoUnwind(
true);
1043 func.setWillReturn(
true);
1048 LLVM::LLVMFuncOp func,
1050 auto call = builder.
create<LLVM::CallOp>(loc, func, args);
1051 call.setCConv(func.getCConv());
1052 call.setConvergentAttr(func.getConvergentAttr());
1053 call.setNoUnwindAttr(func.getNoUnwindAttr());
1054 call.setWillReturnAttr(func.getWillReturnAttr());
1058 template <
typename BarrierOpTy>
1065 static constexpr StringRef getFuncName();
1068 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1070 constexpr StringRef funcName = getFuncName();
1072 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1076 Type voidTy = rewriter.
getType<LLVM::LLVMVoidType>();
1077 LLVM::LLVMFuncOp func =
1080 Location loc = controlBarrierOp->getLoc();
1081 Value execution = rewriter.
create<LLVM::ConstantOp>(
1082 loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1084 loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1085 Value semantics = rewriter.
create<LLVM::ConstantOp>(
1086 loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1089 {execution, memory, semantics});
1091 rewriter.
replaceOp(controlBarrierOp, call);
1098 StringRef getTypeMangling(
Type type,
bool isSigned) {
1100 .Case<Float16Type>([](
auto) {
return "Dh"; })
1101 .Case<Float32Type>([](
auto) {
return "f"; })
1102 .Case<Float64Type>([](
auto) {
return "d"; })
1103 .Case<IntegerType>([isSigned](IntegerType intTy) {
1104 switch (intTy.getWidth()) {
1108 return (isSigned) ?
"a" :
"c";
1110 return (isSigned) ?
"s" :
"t";
1112 return (isSigned) ?
"i" :
"j";
1114 return (isSigned) ?
"l" :
"m";
1116 llvm_unreachable(
"Unsupported integer width");
1120 llvm_unreachable(
"No mangling defined");
1125 template <
typename ReduceOp>
1126 constexpr StringLiteral getGroupFuncName();
1129 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1130 return "_Z17__spirv_GroupIAddii";
1133 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1134 return "_Z17__spirv_GroupFAddii";
1137 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1138 return "_Z17__spirv_GroupSMinii";
1141 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1142 return "_Z17__spirv_GroupUMinii";
1145 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1146 return "_Z17__spirv_GroupFMinii";
1149 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1150 return "_Z17__spirv_GroupSMaxii";
1153 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1154 return "_Z17__spirv_GroupUMaxii";
1157 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1158 return "_Z17__spirv_GroupFMaxii";
1161 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1162 return "_Z27__spirv_GroupNonUniformIAddii";
1165 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1166 return "_Z27__spirv_GroupNonUniformFAddii";
1169 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1170 return "_Z27__spirv_GroupNonUniformIMulii";
1173 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1174 return "_Z27__spirv_GroupNonUniformFMulii";
1177 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1178 return "_Z27__spirv_GroupNonUniformSMinii";
1181 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1182 return "_Z27__spirv_GroupNonUniformUMinii";
1185 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1186 return "_Z27__spirv_GroupNonUniformFMinii";
1189 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1190 return "_Z27__spirv_GroupNonUniformSMaxii";
1193 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1194 return "_Z27__spirv_GroupNonUniformUMaxii";
1197 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1198 return "_Z27__spirv_GroupNonUniformFMaxii";
1201 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1202 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1205 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1206 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1209 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1210 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1213 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1214 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1217 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1218 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1221 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1222 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1226 template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1232 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1235 Type retTy = op.getResult().getType();
1240 funcName += getTypeMangling(retTy,
false);
1244 if constexpr (NonUniform) {
1245 if (adaptor.getClusterSize()) {
1247 paramTypes.push_back(i32Ty);
1252 op->template getParentWithTrait<OpTrait::SymbolTable>();
1254 LLVM::LLVMFuncOp func =
1259 loc, i32Ty,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1260 Value groupOp = rewriter.
create<LLVM::ConstantOp>(
1261 loc, i32Ty,
static_cast<int32_t
>(adaptor.getGroupOperation()));
1263 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1273 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1274 return "_Z22__spirv_ControlBarrieriii";
1279 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1280 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1285 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1286 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1342 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1349 if (loopOp.getBody().empty()) {
1364 Block *entryBlock = loopOp.getEntryBlock();
1366 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1369 Block *headerBlock = loopOp.getHeaderBlock();
1371 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1375 Block *mergeBlock = loopOp.getMergeBlock();
1379 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1395 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1407 if (op.getBody().getBlocks().size() <= 2) {
1419 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1425 auto *headerBlock = op.getHeaderBlock();
1427 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1433 auto *mergeBlock = op.getMergeBlock();
1437 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1440 Block *trueBlock = condBrOp.getTrueBlock();
1441 Block *falseBlock = condBrOp.getFalseBlock();
1443 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1444 condBrOp.getTrueTargetOperands(),
1446 condBrOp.getFalseTargetOperands());
1450 rewriter.
replaceOp(op, continueBlock->getArguments());
1459 template <
typename SPIRVOp,
typename LLVMOp>
1465 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1468 auto dstType = this->getTypeConverter()->convertType(op.getType());
1472 Type op1Type = op.getOperand1().getType();
1473 Type op2Type = op.getOperand2().getType();
1475 if (op1Type == op2Type) {
1476 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477 adaptor.getOperands());
1481 std::optional<uint64_t> dstTypeWidth =
1483 std::optional<uint64_t> op2TypeWidth =
1486 if (!dstTypeWidth || !op2TypeWidth)
1491 if (op2TypeWidth < dstTypeWidth) {
1493 extended = rewriter.template create<LLVM::ZExtOp>(
1494 loc, dstType, adaptor.getOperand2());
1496 extended = rewriter.template create<LLVM::SExtOp>(
1497 loc, dstType, adaptor.getOperand2());
1499 }
else if (op2TypeWidth == dstTypeWidth) {
1500 extended = adaptor.getOperand2();
1505 Value result = rewriter.template create<LLVMOp>(
1506 loc, dstType, adaptor.getOperand1(), extended);
1517 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1519 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1524 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1525 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1542 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1544 auto srcType = tanhOp.getType();
1545 auto dstType = getTypeConverter()->convertType(srcType);
1552 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1553 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1556 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1558 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1570 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1572 auto srcType = varOp.getType();
1574 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1575 auto init = varOp.getInitializer();
1576 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1579 auto dstType = getTypeConverter()->convertType(srcType);
1586 auto elementType = getTypeConverter()->convertType(pointerTo);
1593 auto elementType = getTypeConverter()->convertType(pointerTo);
1597 rewriter.
create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1598 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1608 class BitcastConversionPattern
1614 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1616 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1621 if (isa<LLVM::LLVMPointerType>(dstType)) {
1622 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1627 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1641 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1646 auto funcType = funcOp.getFunctionType();
1648 funcType.getNumInputs());
1650 ->convertFunctionSignature(
1652 false, signatureConverter);
1658 StringRef name = funcOp.getName();
1659 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1663 switch (funcOp.getFunctionControl()) {
1664 case spirv::FunctionControl::Inline:
1665 newFuncOp.setAlwaysInline(
true);
1667 case spirv::FunctionControl::DontInline:
1668 newFuncOp.setNoInline(
true);
1671 #define DISPATCH(functionControl, llvmAttr) \
1672 case functionControl: \
1673 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1676 DISPATCH(spirv::FunctionControl::Pure,
1678 DISPATCH(spirv::FunctionControl::Const,
1692 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1709 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1713 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1717 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1718 rewriter.
eraseOp(spvModuleOp);
1727 class VectorShufflePattern
1732 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1735 auto components = adaptor.getComponents();
1736 auto vector1 = adaptor.getVector1();
1737 auto vector2 = adaptor.getVector2();
1738 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1739 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1740 if (vector1Size == vector2Size) {
1742 op, vector1, vector2,
1743 LLVM::convertArrayToIndices<int32_t>(components));
1747 auto dstType = getTypeConverter()->convertType(op.getType());
1750 auto scalarType = cast<VectorType>(dstType).getElementType();
1751 auto componentsArray = components.getValue();
1754 Value targetOp = rewriter.
create<LLVM::PoisonOp>(loc, dstType);
1755 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1756 if (!isa<IntegerAttr>(componentsArray[i]))
1757 return op.
emitError(
"unable to support non-constant component");
1759 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1764 Value baseVector = vector1;
1765 if (indexVal >= vector1Size) {
1766 offsetVal = vector1Size;
1767 baseVector = vector2;
1770 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1776 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1777 loc, scalarType, baseVector, index);
1778 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1779 extractOp, dstIndex);
1792 spirv::ClientAPI clientAPI) {
1809 spirv::ClientAPI clientAPI) {
1812 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1813 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1814 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1815 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1816 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1817 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1818 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1819 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1820 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1821 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1822 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1823 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1824 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1827 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1828 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1829 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1830 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1831 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1832 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1833 NotPattern<spirv::NotOp>,
1836 BitcastConversionPattern,
1837 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1838 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1839 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1840 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1841 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1842 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1843 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1846 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1847 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1848 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1849 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1850 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1851 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1852 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1853 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1854 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1855 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1856 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1857 LLVM::FCmpPredicate::uge>,
1858 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1859 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1860 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1861 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1862 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1863 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1864 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1865 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1866 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1867 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1868 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1871 ConstantScalarAndVectorPattern,
1874 BranchConversionPattern, BranchConditionalConversionPattern,
1875 FunctionCallPattern, LoopPattern, SelectionPattern,
1876 ErasePattern<spirv::MergeOp>,
1879 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1882 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1883 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1884 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1885 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1886 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1887 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1888 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1889 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1890 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1891 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1892 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1893 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1894 InverseSqrtPattern, TanPattern, TanhPattern,
1897 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1898 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1899 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1900 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1901 NotPattern<spirv::LogicalNotOp>,
1904 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1905 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1908 CompositeExtractPattern, CompositeInsertPattern,
1909 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1910 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1911 VectorShufflePattern,
1914 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1915 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1916 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1919 ReturnPattern, ReturnValuePattern,
1922 ControlBarrierPattern<spirv::ControlBarrierOp>,
1923 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1924 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1927 GroupReducePattern<spirv::GroupIAddOp>,
1928 GroupReducePattern<spirv::GroupFAddOp>,
1929 GroupReducePattern<spirv::GroupFMinOp>,
1930 GroupReducePattern<spirv::GroupUMinOp>,
1931 GroupReducePattern<spirv::GroupSMinOp,
true>,
1932 GroupReducePattern<spirv::GroupFMaxOp>,
1933 GroupReducePattern<spirv::GroupUMaxOp>,
1934 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1935 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1937 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1939 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1941 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1943 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1945 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1947 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1949 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1951 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1953 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1955 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1957 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1959 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1961 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1963 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1965 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1980 patterns.add<ModuleConversionPattern>(
patterns.getContext(), typeConverter);
1991 auto spvModules = module.getOps<spirv::ModuleOp>();
1992 for (
auto spvModule : spvModules) {
1993 spvModule.walk([&](spirv::GlobalVariableOp op) {
1994 IntegerAttr descriptorSet =
1996 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
1999 if (descriptorSet && binding) {
2002 auto moduleAndName =
2003 spvModule.getName().has_value()
2004 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
2005 : op.getSymName().str();
2007 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
2008 std::to_string(descriptorSet.getInt()),
2009 std::to_string(binding.getInt()));
2015 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
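Utility function for bitfield ops (BitFieldInsert, BitFieldSExtract and BitFieldUExtract): truncates or zero-extends the value to the bit width of llvmType; the value is left unchanged when the bit widths already match.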
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.