25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
46 if (
auto vecType = dyn_cast<VectorType>(type))
47 return vecType.getElementType().isSignedInteger();
55 if (
auto vecType = dyn_cast<VectorType>(type))
56 return vecType.getElementType().isUnsignedInteger();
63 if (
auto intType = dyn_cast<IntegerType>(type))
64 return intType.getWidth();
65 if (
auto vecType = dyn_cast<VectorType>(type))
66 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
67 return intType.getWidth();
74 "bitwidth is not supported for this type");
77 auto vecType = dyn_cast<VectorType>(type);
78 auto elementType = vecType.getElementType();
79 assert(elementType.isIntOrFloat() &&
80 "only integers and floats have a bitwidth");
81 return elementType.getIntOrFloatBitWidth();
94 if (
auto vecType = dyn_cast<VectorType>(type)) {
95 auto integerType = cast<IntegerType>(vecType.getElementType());
98 auto integerType = cast<IntegerType>(type);
105 if (isa<VectorType>(srcType)) {
106 return rewriter.
create<LLVM::ConstantOp>(
111 return rewriter.
create<LLVM::ConstantOp>(
118 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
119 auto floatType = cast<FloatType>(vecType.getElementType());
120 return rewriter.
create<LLVM::ConstantOp>(
125 auto floatType = cast<FloatType>(srcType);
126 return rewriter.
create<LLVM::ConstantOp>(
139 auto srcType = value.
getType();
145 if (valueBitWidth < targetBitWidth)
146 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
151 if (valueBitWidth > targetBitWidth)
152 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
161 auto llvmVectorType = typeConverter.
convertType(vectorType);
163 Value broadcasted = rewriter.
create<LLVM::UndefOp>(loc, llvmVectorType);
164 for (
unsigned i = 0; i < numElements; ++i) {
165 auto index = rewriter.
create<LLVM::ConstantOp>(
167 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
168 loc, llvmVectorType, broadcasted, toBroadcast, index);
177 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
178 unsigned numElements = vectorType.getNumElements();
179 return broadcast(loc, value, numElements, typeConverter, rewriter);
229 return rewriter.
create<LLVM::ConstantOp>(
238 unsigned alignment,
bool isVolatile,
239 bool isNonTemporal) {
240 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
241 auto dstType = typeConverter.
convertType(loadOp.getType());
245 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
246 isVolatile, isNonTemporal);
249 auto storeOp = cast<spirv::StoreOp>(op);
250 spirv::StoreOpAdaptor adaptor(operands);
252 adaptor.getPtr(), alignment,
253 isVolatile, isNonTemporal);
268 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
269 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
272 auto llvmElementType = converter.
convertType(elementType);
281 switch (storageClass) {
282 #define STORAGE_SPACE_MAP(storage, space) \
283 case spirv::StorageClass::storage: \
293 #undef STORAGE_SPACE_MAP
300 spirv::StorageClass storageClass) {
302 #define CLIENT_MAP(client, storage) \
303 case spirv::ClientAPI::client: \
304 return mapTo##client##AddressSpace(storage);
316 spirv::ClientAPI clientAPI) {
338 if (!memberDecorations.empty())
356 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
358 auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
362 auto indices = llvm::to_vector<4>(adaptor.getIndices());
363 Type indexType = op.getIndices().front().getType();
364 auto llvmIndexType = typeConverter.convertType(indexType);
369 indices.insert(indices.begin(), zero);
371 auto elementType = typeConverter.convertType(
372 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
376 adaptor.getBasePtr(), indices);
386 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
388 auto dstType = typeConverter.convertType(op.getPointer().getType());
397 class BitFieldInsertPattern
403 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
405 auto srcType = op.getType();
406 auto dstType = typeConverter.convertType(srcType);
413 typeConverter, rewriter);
415 typeConverter, rewriter);
419 Value maskShiftedByCount =
420 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
421 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
422 maskShiftedByCount, minusOne);
423 Value maskShiftedByCountAndOffset =
424 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
426 loc, dstType, maskShiftedByCountAndOffset, minusOne);
431 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
432 Value insertShiftedByOffset =
433 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
435 insertShiftedByOffset);
441 class ConstantScalarAndVectorPattern
447 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
449 auto srcType = constOp.getType();
450 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
453 auto dstType = typeConverter.convertType(srcType);
466 if (isa<VectorType>(srcType)) {
467 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
470 dstElementsAttr.mapValues(
471 signlessType, [&](
const APInt &value) {
return value; }));
474 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
475 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
480 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
485 class BitFieldSExtractPattern
491 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
493 auto srcType = op.getType();
494 auto dstType = typeConverter.convertType(srcType);
501 typeConverter, rewriter);
503 typeConverter, rewriter);
506 IntegerType integerType;
507 if (
auto vecType = dyn_cast<VectorType>(srcType))
508 integerType = cast<IntegerType>(vecType.getElementType());
510 integerType = cast<IntegerType>(srcType);
514 isa<VectorType>(srcType)
515 ? rewriter.
create<LLVM::ConstantOp>(
518 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
522 Value countPlusOffset =
523 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
524 Value amountToShiftLeft =
525 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
526 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
527 loc, dstType, op.getBase(), amountToShiftLeft);
530 Value amountToShiftRight =
531 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
538 class BitFieldUExtractPattern
544 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
546 auto srcType = op.getType();
547 auto dstType = typeConverter.convertType(srcType);
554 typeConverter, rewriter);
556 typeConverter, rewriter);
560 Value maskShiftedByCount =
561 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
562 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
567 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
578 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
581 branchOp.getTarget());
586 class BranchConditionalConversionPattern
593 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
597 if (
auto weights = op.getBranchWeights()) {
599 for (
auto weight : weights->getAsRange<IntegerAttr>())
600 weightValues.push_back(weight.getInt());
605 op, op.getCondition(), op.getTrueBlockArguments(),
606 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
615 class CompositeExtractPattern
621 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
623 auto dstType = this->typeConverter.convertType(op.getType());
627 Type containerType = op.getComposite().getType();
628 if (isa<VectorType>(containerType)) {
630 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
633 op, dstType, adaptor.getComposite(), index);
638 op, adaptor.getComposite(),
647 class CompositeInsertPattern
653 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
655 auto dstType = this->typeConverter.convertType(op.getType());
659 Type containerType = op.getComposite().getType();
660 if (isa<VectorType>(containerType)) {
662 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
665 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
670 op, adaptor.getComposite(), adaptor.getObject(),
678 template <
typename SPIRVOp,
typename LLVMOp>
684 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
686 auto dstType = this->typeConverter.convertType(op.getType());
689 rewriter.template replaceOpWithNewOp<LLVMOp>(
690 op, dstType, adaptor.getOperands(), op->
getAttrs());
697 class ExecutionModePattern
703 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
709 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
710 std::string moduleName;
711 if (module.getName().has_value())
712 moduleName =
"_" + module.
getName()->str();
715 std::string executionModeInfoName = llvm::formatv(
716 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
717 static_cast<uint32_t
>(executionModeAttr.getValue()));
730 fields.push_back(llvmI32Type);
731 ArrayAttr values = op.getValues();
732 if (!values.empty()) {
734 fields.push_back(arrayType);
739 auto global = rewriter.
create<LLVM::GlobalOp>(
741 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
744 Region &region = global.getInitializerRegion();
749 Value structValue = rewriter.
create<LLVM::UndefOp>(loc, structType);
750 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
753 static_cast<uint32_t
>(executionModeAttr.getValue())));
754 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
758 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
759 auto attr = values.getValue()[i];
760 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
761 structValue = rewriter.
create<LLVM::InsertValueOp>(
774 class GlobalVariablePattern
777 template <
typename... Args>
778 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
780 std::forward<Args>(args)...),
781 clientAPI(clientAPI) {}
784 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
788 if (op.getInitializer())
791 auto srcType = cast<spirv::PointerType>(op.getType());
792 auto dstType = typeConverter.convertType(srcType.getPointeeType());
799 auto storageClass = srcType.getStorageClass();
800 switch (storageClass) {
801 case spirv::StorageClass::Input:
802 case spirv::StorageClass::Private:
803 case spirv::StorageClass::Output:
804 case spirv::StorageClass::StorageBuffer:
805 case spirv::StorageClass::UniformConstant:
814 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
815 (storageClass == spirv::StorageClass::UniformConstant);
821 auto linkage = storageClass == spirv::StorageClass::Private
822 ? LLVM::Linkage::Private
823 : LLVM::Linkage::External;
825 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
829 if (op.getLocationAttr())
830 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
836 spirv::ClientAPI clientAPI;
841 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
847 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
851 Type toType = op.getType();
853 auto dstType = this->typeConverter.convertType(toType);
858 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
859 adaptor.getOperands());
863 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
864 adaptor.getOperands());
871 class FunctionCallPattern
877 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
879 if (callOp.getNumResults() == 0) {
881 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
886 auto dstType = typeConverter.convertType(callOp.getType(0));
890 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
896 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
902 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
905 auto dstType = this->typeConverter.convertType(op.getType());
909 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
910 op, dstType, predicate, op.getOperand1(), op.getOperand2());
916 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
922 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
925 auto dstType = this->typeConverter.convertType(op.getType());
929 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
930 op, dstType, predicate, op.getOperand1(), op.getOperand2());
935 class InverseSqrtPattern
941 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
943 auto srcType = op.getType();
944 auto dstType = typeConverter.convertType(srcType);
957 template <
typename SPIRVOp>
963 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
965 if (!op.getMemoryAccess()) {
967 this->typeConverter, 0,
971 auto memoryAccess = *op.getMemoryAccess();
972 switch (memoryAccess) {
973 case spirv::MemoryAccess::Aligned:
975 case spirv::MemoryAccess::Nontemporal:
976 case spirv::MemoryAccess::Volatile: {
978 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
979 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
980 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
982 this->typeConverter, alignment, isVolatile,
993 template <
typename SPIRVOp>
999 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
1001 auto srcType = notOp.getType();
1002 auto dstType = this->typeConverter.convertType(srcType);
1009 isa<VectorType>(srcType)
1010 ? rewriter.
create<LLVM::ConstantOp>(
1013 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
1014 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
1015 notOp.getOperand(), mask);
1021 template <
typename SPIRVOp>
1027 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1039 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1052 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1055 adaptor.getOperands());
1113 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1129 Block *entryBlock = loopOp.getEntryBlock();
1131 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1134 Block *headerBlock = loopOp.getHeaderBlock();
1136 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1140 Block *mergeBlock = loopOp.getMergeBlock();
1144 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1160 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1172 if (op.getBody().getBlocks().size() <= 2) {
1184 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1190 auto *headerBlock = op.getHeaderBlock();
1192 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1199 auto *mergeBlock = op.getMergeBlock();
1203 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1206 Block *trueBlock = condBrOp.getTrueBlock();
1207 Block *falseBlock = condBrOp.getFalseBlock();
1209 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1210 condBrOp.getTrueTargetOperands(),
1212 condBrOp.getFalseTargetOperands());
1215 rewriter.
replaceOp(op, continueBlock->getArguments());
1224 template <
typename SPIRVOp,
typename LLVMOp>
1230 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1233 auto dstType = this->typeConverter.convertType(op.getType());
1237 Type op1Type = op.getOperand1().getType();
1238 Type op2Type = op.getOperand2().getType();
1240 if (op1Type == op2Type) {
1241 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1242 adaptor.getOperands());
1246 std::optional<uint64_t> dstTypeWidth =
1248 std::optional<uint64_t> op2TypeWidth =
1251 if (!dstTypeWidth || !op2TypeWidth)
1256 if (op2TypeWidth < dstTypeWidth) {
1258 extended = rewriter.template create<LLVM::ZExtOp>(
1259 loc, dstType, adaptor.getOperand2());
1261 extended = rewriter.template create<LLVM::SExtOp>(
1262 loc, dstType, adaptor.getOperand2());
1264 }
else if (op2TypeWidth == dstTypeWidth) {
1265 extended = adaptor.getOperand2();
1270 Value result = rewriter.template create<LLVMOp>(
1271 loc, dstType, adaptor.getOperand1(), extended);
1282 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1284 auto dstType = typeConverter.convertType(tanOp.getType());
1289 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1290 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1307 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1309 auto srcType = tanhOp.getType();
1310 auto dstType = typeConverter.convertType(srcType);
1317 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1318 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1321 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1323 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1335 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1337 auto srcType = varOp.getType();
1339 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1340 auto init = varOp.getInitializer();
1341 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1344 auto dstType = typeConverter.convertType(srcType);
1351 auto elementType = typeConverter.convertType(pointerTo);
1358 auto elementType = typeConverter.convertType(pointerTo);
1362 rewriter.
create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1363 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1373 class BitcastConversionPattern
1379 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1381 auto dstType = typeConverter.convertType(bitcastOp.getType());
1386 if (isa<LLVM::LLVMPointerType>(dstType)) {
1387 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1392 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1406 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1411 auto funcType = funcOp.getFunctionType();
1413 funcType.getNumInputs());
1414 auto llvmType = typeConverter.convertFunctionSignature(
1415 funcType,
false,
false,
1416 signatureConverter);
1422 StringRef name = funcOp.getName();
1423 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1427 switch (funcOp.getFunctionControl()) {
1428 #define DISPATCH(functionControl, llvmAttr) \
1429 case functionControl: \
1430 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1433 DISPATCH(spirv::FunctionControl::Inline,
1435 DISPATCH(spirv::FunctionControl::DontInline,
1437 DISPATCH(spirv::FunctionControl::Pure,
1439 DISPATCH(spirv::FunctionControl::Const,
1453 &signatureConverter))) {
1470 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1474 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1478 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1479 rewriter.
eraseOp(spvModuleOp);
1488 class VectorShufflePattern
1493 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1496 auto components = adaptor.getComponents();
1497 auto vector1 = adaptor.getVector1();
1498 auto vector2 = adaptor.getVector2();
1499 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1500 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1501 if (vector1Size == vector2Size) {
1503 op, vector1, vector2,
1504 LLVM::convertArrayToIndices<int32_t>(components));
1508 auto dstType = typeConverter.convertType(op.getType());
1511 auto scalarType = cast<VectorType>(dstType).getElementType();
1512 auto componentsArray = components.getValue();
1515 Value targetOp = rewriter.
create<LLVM::UndefOp>(loc, dstType);
1516 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1517 if (!isa<IntegerAttr>(componentsArray[i]))
1518 return op.
emitError(
"unable to support non-constant component");
1520 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1525 Value baseVector = vector1;
1526 if (indexVal >= vector1Size) {
1527 offsetVal = vector1Size;
1528 baseVector = vector2;
1531 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1537 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1538 loc, scalarType, baseVector, index);
1539 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1540 extractOp, dstIndex);
1553 spirv::ClientAPI clientAPI) {
1570 spirv::ClientAPI clientAPI) {
1573 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1574 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1575 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1576 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1577 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1578 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1579 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1580 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1581 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1582 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1583 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1584 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1585 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1588 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1589 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1590 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1591 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1592 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1593 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1594 NotPattern<spirv::NotOp>,
1597 BitcastConversionPattern,
1598 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1599 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1600 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1601 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1602 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1603 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1604 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1607 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1608 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1609 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1610 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1611 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1612 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1613 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1614 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1615 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1616 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1617 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1618 LLVM::FCmpPredicate::uge>,
1619 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1620 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1621 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1622 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1623 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1624 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1625 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1626 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1627 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1628 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1629 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1632 ConstantScalarAndVectorPattern,
1635 BranchConversionPattern, BranchConditionalConversionPattern,
1636 FunctionCallPattern, LoopPattern, SelectionPattern,
1637 ErasePattern<spirv::MergeOp>,
1640 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1643 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1644 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1645 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1646 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1647 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1648 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1649 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1650 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1651 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1652 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1653 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1654 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1655 InverseSqrtPattern, TanPattern, TanhPattern,
1658 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1659 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1660 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1661 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1662 NotPattern<spirv::LogicalNotOp>,
1665 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1666 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1669 CompositeExtractPattern, CompositeInsertPattern,
1670 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1671 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1672 VectorShufflePattern,
1675 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1676 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1677 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1680 ReturnPattern, ReturnValuePattern>(patterns.
getContext(), typeConverter);
1682 patterns.
add<GlobalVariablePattern>(clientAPI, patterns.
getContext(),
1688 patterns.
add<FuncConversionPattern>(patterns.
getContext(), typeConverter);
1693 patterns.
add<ModuleConversionPattern>(patterns.
getContext(), typeConverter);
1704 auto spvModules = module.getOps<spirv::ModuleOp>();
1705 for (
auto spvModule : spvModules) {
1706 spvModule.walk([&](spirv::GlobalVariableOp op) {
1707 IntegerAttr descriptorSet =
1712 if (descriptorSet && binding) {
1715 auto moduleAndName =
1716 spvModule.getName().has_value()
1717 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
1718 : op.getSymName().str();
1720 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
1721 std::to_string(descriptorSet.getInt()),
1722 std::to_string(binding.getInt()));
1728 op.
emitError(
"unable to replace all symbol uses for ") << name;
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops: zero-extends or truncates the value to llvmType when their bit widths differ.
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass)
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
constexpr unsigned defaultAddressSpace
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
#define STORAGE_SPACE_MAP(storage, space)
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or a vector of signed integers.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or a vector of unsigned integers.
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
static Type convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
#define CLIENT_MAP(client, storage)
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Values.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.