25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
40 if (
auto vecType = dyn_cast<VectorType>(type))
41 return vecType.getElementType().isSignedInteger();
49 if (
auto vecType = dyn_cast<VectorType>(type))
50 return vecType.getElementType().isUnsignedInteger();
57 if (
auto intType = dyn_cast<IntegerType>(type))
58 return intType.getWidth();
59 if (
auto vecType = dyn_cast<VectorType>(type))
60 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
61 return intType.getWidth();
68 "bitwidth is not supported for this type");
71 auto vecType = dyn_cast<VectorType>(type);
72 auto elementType = vecType.getElementType();
73 assert(elementType.isIntOrFloat() &&
74 "only integers and floats have a bitwidth");
75 return elementType.getIntOrFloatBitWidth();
88 if (
auto vecType = dyn_cast<VectorType>(type)) {
89 auto integerType = cast<IntegerType>(vecType.getElementType());
92 auto integerType = cast<IntegerType>(type);
99 if (isa<VectorType>(srcType)) {
100 return rewriter.
create<LLVM::ConstantOp>(
105 return rewriter.
create<LLVM::ConstantOp>(
112 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
113 auto floatType = cast<FloatType>(vecType.getElementType());
114 return rewriter.
create<LLVM::ConstantOp>(
119 auto floatType = cast<FloatType>(srcType);
120 return rewriter.
create<LLVM::ConstantOp>(
133 auto srcType = value.
getType();
139 if (valueBitWidth < targetBitWidth)
140 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
145 if (valueBitWidth > targetBitWidth)
146 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
155 auto llvmVectorType = typeConverter.
convertType(vectorType);
157 Value broadcasted = rewriter.
create<LLVM::UndefOp>(loc, llvmVectorType);
158 for (
unsigned i = 0; i < numElements; ++i) {
159 auto index = rewriter.
create<LLVM::ConstantOp>(
161 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
162 loc, llvmVectorType, broadcasted, toBroadcast, index);
171 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
172 unsigned numElements = vectorType.getNumElements();
173 return broadcast(loc, value, numElements, typeConverter, rewriter);
223 return rewriter.
create<LLVM::ConstantOp>(
232 unsigned alignment,
bool isVolatile,
233 bool isNonTemporal) {
234 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
235 auto dstType = typeConverter.
convertType(loadOp.getType());
239 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
240 isVolatile, isNonTemporal);
243 auto storeOp = cast<spirv::StoreOp>(op);
244 spirv::StoreOpAdaptor adaptor(operands);
246 adaptor.getPtr(), alignment,
247 isVolatile, isNonTemporal);
262 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
263 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
266 auto llvmElementType = converter.
convertType(elementType);
275 spirv::ClientAPI clientAPI) {
276 unsigned addressSpace =
298 if (!memberDecorations.empty())
316 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
319 getTypeConverter()->convertType(op.getComponentPtr().getType());
323 auto indices = llvm::to_vector<4>(adaptor.getIndices());
324 Type indexType = op.getIndices().front().getType();
325 auto llvmIndexType = getTypeConverter()->convertType(indexType);
330 indices.insert(indices.begin(), zero);
332 auto elementType = getTypeConverter()->convertType(
333 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
337 adaptor.getBasePtr(), indices);
347 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
349 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
358 class BitFieldInsertPattern
364 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
366 auto srcType = op.getType();
367 auto dstType = getTypeConverter()->convertType(srcType);
374 *getTypeConverter(), rewriter);
376 *getTypeConverter(), rewriter);
380 Value maskShiftedByCount =
381 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
382 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
383 maskShiftedByCount, minusOne);
384 Value maskShiftedByCountAndOffset =
385 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
387 loc, dstType, maskShiftedByCountAndOffset, minusOne);
392 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
393 Value insertShiftedByOffset =
394 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
396 insertShiftedByOffset);
402 class ConstantScalarAndVectorPattern
408 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
410 auto srcType = constOp.getType();
411 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
414 auto dstType = getTypeConverter()->convertType(srcType);
427 if (isa<VectorType>(srcType)) {
428 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
431 dstElementsAttr.mapValues(
432 signlessType, [&](
const APInt &value) {
return value; }));
435 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
436 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
441 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
446 class BitFieldSExtractPattern
452 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
454 auto srcType = op.getType();
455 auto dstType = getTypeConverter()->convertType(srcType);
462 *getTypeConverter(), rewriter);
464 *getTypeConverter(), rewriter);
467 IntegerType integerType;
468 if (
auto vecType = dyn_cast<VectorType>(srcType))
469 integerType = cast<IntegerType>(vecType.getElementType());
471 integerType = cast<IntegerType>(srcType);
475 isa<VectorType>(srcType)
476 ? rewriter.
create<LLVM::ConstantOp>(
479 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
483 Value countPlusOffset =
484 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
485 Value amountToShiftLeft =
486 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
487 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
488 loc, dstType, op.getBase(), amountToShiftLeft);
491 Value amountToShiftRight =
492 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
499 class BitFieldUExtractPattern
505 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
507 auto srcType = op.getType();
508 auto dstType = getTypeConverter()->convertType(srcType);
515 *getTypeConverter(), rewriter);
517 *getTypeConverter(), rewriter);
521 Value maskShiftedByCount =
522 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
523 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
528 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
539 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
542 branchOp.getTarget());
547 class BranchConditionalConversionPattern
554 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
558 if (
auto weights = op.getBranchWeights()) {
560 for (
auto weight : weights->getAsRange<IntegerAttr>())
561 weightValues.push_back(weight.getInt());
566 op, op.getCondition(), op.getTrueBlockArguments(),
567 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
576 class CompositeExtractPattern
582 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
584 auto dstType = this->getTypeConverter()->convertType(op.getType());
588 Type containerType = op.getComposite().getType();
589 if (isa<VectorType>(containerType)) {
591 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
594 op, dstType, adaptor.getComposite(), index);
599 op, adaptor.getComposite(),
608 class CompositeInsertPattern
614 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
616 auto dstType = this->getTypeConverter()->convertType(op.getType());
620 Type containerType = op.getComposite().getType();
621 if (isa<VectorType>(containerType)) {
623 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
626 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
631 op, adaptor.getComposite(), adaptor.getObject(),
639 template <
typename SPIRVOp,
typename LLVMOp>
645 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
647 auto dstType = this->getTypeConverter()->convertType(op.getType());
650 rewriter.template replaceOpWithNewOp<LLVMOp>(
651 op, dstType, adaptor.getOperands(), op->
getAttrs());
658 class ExecutionModePattern
664 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
670 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
671 std::string moduleName;
672 if (module.getName().has_value())
673 moduleName =
"_" + module.
getName()->str();
676 std::string executionModeInfoName = llvm::formatv(
677 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
678 static_cast<uint32_t
>(executionModeAttr.getValue()));
691 fields.push_back(llvmI32Type);
692 ArrayAttr values = op.getValues();
693 if (!values.empty()) {
695 fields.push_back(arrayType);
700 auto global = rewriter.
create<LLVM::GlobalOp>(
702 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
705 Region ®ion = global.getInitializerRegion();
710 Value structValue = rewriter.
create<LLVM::UndefOp>(loc, structType);
711 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
714 static_cast<uint32_t
>(executionModeAttr.getValue())));
715 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
719 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
720 auto attr = values.getValue()[i];
721 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
722 structValue = rewriter.
create<LLVM::InsertValueOp>(
735 class GlobalVariablePattern
738 template <
typename... Args>
739 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
741 std::forward<Args>(args)...),
742 clientAPI(clientAPI) {}
745 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
749 if (op.getInitializer())
752 auto srcType = cast<spirv::PointerType>(op.getType());
753 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
760 auto storageClass = srcType.getStorageClass();
761 switch (storageClass) {
762 case spirv::StorageClass::Input:
763 case spirv::StorageClass::Private:
764 case spirv::StorageClass::Output:
765 case spirv::StorageClass::StorageBuffer:
766 case spirv::StorageClass::UniformConstant:
775 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
776 (storageClass == spirv::StorageClass::UniformConstant);
782 auto linkage = storageClass == spirv::StorageClass::Private
783 ? LLVM::Linkage::Private
784 : LLVM::Linkage::External;
786 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
790 if (op.getLocationAttr())
791 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
797 spirv::ClientAPI clientAPI;
802 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
808 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
812 Type toType = op.getType();
814 auto dstType = this->getTypeConverter()->convertType(toType);
819 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
820 adaptor.getOperands());
824 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
825 adaptor.getOperands());
832 class FunctionCallPattern
838 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840 if (callOp.getNumResults() == 0) {
842 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
843 newOp.getProperties().operandSegmentSizes = {
844 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
850 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
854 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
855 newOp.getProperties().operandSegmentSizes = {
856 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
863 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
869 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
872 auto dstType = this->getTypeConverter()->convertType(op.getType());
876 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
877 op, dstType, predicate, op.getOperand1(), op.getOperand2());
883 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
889 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
892 auto dstType = this->getTypeConverter()->convertType(op.getType());
896 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
897 op, dstType, predicate, op.getOperand1(), op.getOperand2());
902 class InverseSqrtPattern
908 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910 auto srcType = op.getType();
911 auto dstType = getTypeConverter()->convertType(srcType);
924 template <
typename SPIRVOp>
930 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
932 if (!op.getMemoryAccess()) {
934 *this->getTypeConverter(), 0,
938 auto memoryAccess = *op.getMemoryAccess();
939 switch (memoryAccess) {
940 case spirv::MemoryAccess::Aligned:
942 case spirv::MemoryAccess::Nontemporal:
943 case spirv::MemoryAccess::Volatile: {
945 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
946 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
947 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949 *this->getTypeConverter(), alignment,
950 isVolatile, isNonTemporal);
960 template <
typename SPIRVOp>
966 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
968 auto srcType = notOp.getType();
969 auto dstType = this->getTypeConverter()->convertType(srcType);
976 isa<VectorType>(srcType)
977 ? rewriter.
create<LLVM::ConstantOp>(
980 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
981 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
982 notOp.getOperand(), mask);
988 template <
typename SPIRVOp>
994 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1006 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1019 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1022 adaptor.getOperands());
1031 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1037 func = b.create<LLVM::LLVMFuncOp>(
1038 symbolTable->
getLoc(), name,
1040 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1041 func.setConvergent(
true);
1042 func.setNoUnwind(
true);
1043 func.setWillReturn(
true);
1048 LLVM::LLVMFuncOp func,
1050 auto call = builder.
create<LLVM::CallOp>(loc, func, args);
1051 call.setCConv(func.getCConv());
1052 call.setConvergentAttr(func.getConvergentAttr());
1053 call.setNoUnwindAttr(func.getNoUnwindAttr());
1054 call.setWillReturnAttr(func.getWillReturnAttr());
1058 class ControlBarrierPattern
1064 matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
1066 constexpr StringLiteral funcName =
"_Z22__spirv_ControlBarrieriii";
1072 Type voidTy = rewriter.
getType<LLVM::LLVMVoidType>();
1073 LLVM::LLVMFuncOp func =
1076 Location loc = controlBarrierOp->getLoc();
1077 Value execution = rewriter.
create<LLVM::ConstantOp>(
1078 loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1080 loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1081 Value semantics = rewriter.
create<LLVM::ConstantOp>(
1082 loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1085 {execution, memory, semantics});
1087 rewriter.
replaceOp(controlBarrierOp, call);
1145 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1161 Block *entryBlock = loopOp.getEntryBlock();
1163 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1166 Block *headerBlock = loopOp.getHeaderBlock();
1168 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1172 Block *mergeBlock = loopOp.getMergeBlock();
1176 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1192 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1204 if (op.getBody().getBlocks().size() <= 2) {
1216 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1222 auto *headerBlock = op.getHeaderBlock();
1224 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1231 auto *mergeBlock = op.getMergeBlock();
1235 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1238 Block *trueBlock = condBrOp.getTrueBlock();
1239 Block *falseBlock = condBrOp.getFalseBlock();
1241 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1242 condBrOp.getTrueTargetOperands(),
1244 condBrOp.getFalseTargetOperands());
1247 rewriter.
replaceOp(op, continueBlock->getArguments());
1256 template <
typename SPIRVOp,
typename LLVMOp>
1262 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1265 auto dstType = this->getTypeConverter()->convertType(op.getType());
1269 Type op1Type = op.getOperand1().getType();
1270 Type op2Type = op.getOperand2().getType();
1272 if (op1Type == op2Type) {
1273 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1274 adaptor.getOperands());
1278 std::optional<uint64_t> dstTypeWidth =
1280 std::optional<uint64_t> op2TypeWidth =
1283 if (!dstTypeWidth || !op2TypeWidth)
1288 if (op2TypeWidth < dstTypeWidth) {
1290 extended = rewriter.template create<LLVM::ZExtOp>(
1291 loc, dstType, adaptor.getOperand2());
1293 extended = rewriter.template create<LLVM::SExtOp>(
1294 loc, dstType, adaptor.getOperand2());
1296 }
else if (op2TypeWidth == dstTypeWidth) {
1297 extended = adaptor.getOperand2();
1302 Value result = rewriter.template create<LLVMOp>(
1303 loc, dstType, adaptor.getOperand1(), extended);
1314 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1316 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1321 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1322 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1339 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1341 auto srcType = tanhOp.getType();
1342 auto dstType = getTypeConverter()->convertType(srcType);
1349 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1350 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1353 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1355 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1367 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1369 auto srcType = varOp.getType();
1371 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1372 auto init = varOp.getInitializer();
1373 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1376 auto dstType = getTypeConverter()->convertType(srcType);
1383 auto elementType = getTypeConverter()->convertType(pointerTo);
1390 auto elementType = getTypeConverter()->convertType(pointerTo);
1394 rewriter.
create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1395 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1405 class BitcastConversionPattern
1411 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1413 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1418 if (isa<LLVM::LLVMPointerType>(dstType)) {
1419 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1424 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1438 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1443 auto funcType = funcOp.getFunctionType();
1445 funcType.getNumInputs());
1447 ->convertFunctionSignature(
1449 false, signatureConverter);
1455 StringRef name = funcOp.getName();
1456 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1460 switch (funcOp.getFunctionControl()) {
1461 case spirv::FunctionControl::Inline:
1462 newFuncOp.setAlwaysInline(
true);
1464 case spirv::FunctionControl::DontInline:
1465 newFuncOp.setNoInline(
true);
1468 #define DISPATCH(functionControl, llvmAttr) \
1469 case functionControl: \
1470 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1473 DISPATCH(spirv::FunctionControl::Pure,
1475 DISPATCH(spirv::FunctionControl::Const,
1489 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1506 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1510 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1514 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1515 rewriter.
eraseOp(spvModuleOp);
1524 class VectorShufflePattern
1529 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1532 auto components = adaptor.getComponents();
1533 auto vector1 = adaptor.getVector1();
1534 auto vector2 = adaptor.getVector2();
1535 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1536 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1537 if (vector1Size == vector2Size) {
1539 op, vector1, vector2,
1540 LLVM::convertArrayToIndices<int32_t>(components));
1544 auto dstType = getTypeConverter()->convertType(op.getType());
1547 auto scalarType = cast<VectorType>(dstType).getElementType();
1548 auto componentsArray = components.getValue();
1551 Value targetOp = rewriter.
create<LLVM::UndefOp>(loc, dstType);
1552 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1553 if (!isa<IntegerAttr>(componentsArray[i]))
1554 return op.
emitError(
"unable to support non-constant component");
1556 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1561 Value baseVector = vector1;
1562 if (indexVal >= vector1Size) {
1563 offsetVal = vector1Size;
1564 baseVector = vector2;
1567 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1573 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1574 loc, scalarType, baseVector, index);
1575 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1576 extractOp, dstIndex);
1589 spirv::ClientAPI clientAPI) {
1606 spirv::ClientAPI clientAPI) {
1609 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1610 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1611 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1612 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1613 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1614 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1615 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1616 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1617 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1618 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1619 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1620 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1621 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1624 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1625 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1626 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1627 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1628 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1629 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1630 NotPattern<spirv::NotOp>,
1633 BitcastConversionPattern,
1634 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1635 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1636 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1637 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1638 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1639 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1640 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1643 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1644 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1645 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1646 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1647 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1648 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1649 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1650 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1651 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1652 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1653 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1654 LLVM::FCmpPredicate::uge>,
1655 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1656 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1657 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1658 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1659 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1660 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1661 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1662 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1663 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1664 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1665 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1668 ConstantScalarAndVectorPattern,
1671 BranchConversionPattern, BranchConditionalConversionPattern,
1672 FunctionCallPattern, LoopPattern, SelectionPattern,
1673 ErasePattern<spirv::MergeOp>,
1676 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1679 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1680 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1681 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1682 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1683 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1684 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1685 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1686 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1687 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1688 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1689 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1690 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1691 InverseSqrtPattern, TanPattern, TanhPattern,
1694 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1695 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1696 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1697 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1698 NotPattern<spirv::LogicalNotOp>,
1701 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1702 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1705 CompositeExtractPattern, CompositeInsertPattern,
1706 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1707 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1708 VectorShufflePattern,
1711 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1712 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1713 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1716 ReturnPattern, ReturnValuePattern,
1719 ControlBarrierPattern>(patterns.
getContext(), typeConverter);
1721 patterns.
add<GlobalVariablePattern>(clientAPI, patterns.
getContext(),
1727 patterns.
add<FuncConversionPattern>(patterns.
getContext(), typeConverter);
1732 patterns.
add<ModuleConversionPattern>(patterns.
getContext(), typeConverter);
1743 auto spvModules = module.getOps<spirv::ModuleOp>();
1744 for (
auto spvModule : spvModules) {
1745 spvModule.walk([&](spirv::GlobalVariableOp op) {
1746 IntegerAttr descriptorSet =
1751 if (descriptorSet && binding) {
1754 auto moduleAndName =
1755 spvModule.getName().has_value()
1756 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
1757 : op.getSymName().str();
1759 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
1760 std::to_string(descriptorSet.getInt()),
1761 std::to_string(binding.getInt()));
1767 op.
emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.