25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
40 if (
auto vecType = type.
dyn_cast<VectorType>())
41 return vecType.getElementType().isSignedInteger();
49 if (
auto vecType = type.
dyn_cast<VectorType>())
50 return vecType.getElementType().isUnsignedInteger();
57 "bitwidth is not supported for this type");
60 auto vecType = type.
dyn_cast<VectorType>();
61 auto elementType = vecType.getElementType();
62 assert(elementType.isIntOrFloat() &&
63 "only integers and floats have a bitwidth");
64 return elementType.getIntOrFloatBitWidth();
77 if (
auto vecType = type.
dyn_cast<VectorType>()) {
78 auto integerType = vecType.getElementType().cast<IntegerType>();
81 auto integerType = type.
cast<IntegerType>();
88 if (srcType.
isa<VectorType>()) {
89 return rewriter.
create<LLVM::ConstantOp>(
94 return rewriter.
create<LLVM::ConstantOp>(
101 if (
auto vecType = srcType.
dyn_cast<VectorType>()) {
103 return rewriter.
create<LLVM::ConstantOp>(
109 return rewriter.
create<LLVM::ConstantOp>(
122 auto srcType = value.
getType();
128 if (valueBitWidth < targetBitWidth)
129 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
134 if (valueBitWidth > targetBitWidth)
135 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
143 auto vectorType = VectorType::get(numElements, toBroadcast.
getType());
144 auto llvmVectorType = typeConverter.
convertType(vectorType);
146 Value broadcasted = rewriter.
create<LLVM::UndefOp>(loc, llvmVectorType);
147 for (
unsigned i = 0; i < numElements; ++i) {
148 auto index = rewriter.
create<LLVM::ConstantOp>(
150 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
151 loc, llvmVectorType, broadcasted, toBroadcast, index);
160 if (
auto vectorType = srcType.
dyn_cast<VectorType>()) {
161 unsigned numElements = vectorType.getNumElements();
162 return broadcast(loc, value, numElements, typeConverter, rewriter);
187 static std::optional<Type>
193 auto elementsVector = llvm::to_vector<8>(
195 return converter.convertType(elementType);
204 auto elementsVector = llvm::to_vector<8>(
206 return converter.convertType(elementType);
215 return rewriter.
create<LLVM::ConstantOp>(
216 loc, IntegerType::get(rewriter.
getContext(), 32),
224 unsigned alignment,
bool isVolatile,
225 bool isNonTemporal) {
226 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
227 auto dstType = typeConverter.
convertType(loadOp.getType());
231 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
232 isVolatile, isNonTemporal);
235 auto storeOp = cast<spirv::StoreOp>(op);
236 spirv::StoreOpAdaptor adaptor(operands);
238 adaptor.getPtr(), alignment,
239 isVolatile, isNonTemporal);
255 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
258 auto llvmElementType = converter.
convertType(elementType);
260 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
279 return LLVM::LLVMArrayType::get(elementType, 0);
288 if (!memberDecorations.empty())
306 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
308 auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
312 auto indices = llvm::to_vector<4>(adaptor.getIndices());
313 Type indexType = op.getIndices().front().getType();
314 auto llvmIndexType = typeConverter.convertType(indexType);
318 op.getLoc(), llvmIndexType, rewriter.
getIntegerAttr(indexType, 0));
319 indices.insert(indices.begin(), zero);
322 typeConverter.convertType(op.getBasePtr()
326 adaptor.getBasePtr(), indices);
336 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
338 auto dstType = typeConverter.convertType(op.getPointer().getType());
346 class BitFieldInsertPattern
352 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
354 auto srcType = op.getType();
355 auto dstType = typeConverter.convertType(srcType);
362 typeConverter, rewriter);
364 typeConverter, rewriter);
368 Value maskShiftedByCount =
369 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
370 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
371 maskShiftedByCount, minusOne);
372 Value maskShiftedByCountAndOffset =
373 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
375 loc, dstType, maskShiftedByCountAndOffset, minusOne);
380 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
381 Value insertShiftedByOffset =
382 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
384 insertShiftedByOffset);
390 class ConstantScalarAndVectorPattern
396 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
398 auto srcType = constOp.getType();
399 if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
402 auto dstType = typeConverter.convertType(srcType);
415 if (srcType.isa<VectorType>()) {
419 dstElementsAttr.mapValues(
420 signlessType, [&](
const APInt &value) {
return value; }));
423 auto srcAttr = constOp.getValue().cast<IntegerAttr>();
424 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
429 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
434 class BitFieldSExtractPattern
440 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
442 auto srcType = op.getType();
443 auto dstType = typeConverter.convertType(srcType);
450 typeConverter, rewriter);
452 typeConverter, rewriter);
455 IntegerType integerType;
456 if (
auto vecType = srcType.dyn_cast<VectorType>())
457 integerType = vecType.getElementType().cast<IntegerType>();
459 integerType = srcType.cast<IntegerType>();
463 srcType.
isa<VectorType>()
464 ? rewriter.
create<LLVM::ConstantOp>(
467 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
471 Value countPlusOffset =
472 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
473 Value amountToShiftLeft =
474 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
475 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
476 loc, dstType, op.getBase(), amountToShiftLeft);
479 Value amountToShiftRight =
480 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
487 class BitFieldUExtractPattern
493 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
495 auto srcType = op.getType();
496 auto dstType = typeConverter.convertType(srcType);
503 typeConverter, rewriter);
505 typeConverter, rewriter);
509 Value maskShiftedByCount =
510 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
511 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
516 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
527 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
530 branchOp.getTarget());
535 class BranchConditionalConversionPattern
542 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
545 ElementsAttr branchWeights =
nullptr;
546 if (
auto weights = op.getBranchWeights()) {
547 VectorType weightType = VectorType::get(2, rewriter.
getI32Type());
552 op, op.getCondition(), op.getTrueBlockArguments(),
553 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
562 class CompositeExtractPattern
568 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
570 auto dstType = this->typeConverter.convertType(op.getType());
574 Type containerType = op.getComposite().getType();
575 if (containerType.
isa<VectorType>()) {
577 IntegerAttr value = op.getIndices()[0].
cast<IntegerAttr>();
580 op, dstType, adaptor.getComposite(), index);
593 class CompositeInsertPattern
599 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
601 auto dstType = this->typeConverter.convertType(op.getType());
605 Type containerType = op.getComposite().getType();
606 if (containerType.
isa<VectorType>()) {
608 IntegerAttr value = op.getIndices()[0].
cast<IntegerAttr>();
611 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
616 op, adaptor.getComposite(), adaptor.getObject(),
624 template <
typename SPIRVOp,
typename LLVMOp>
630 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
632 auto dstType = this->typeConverter.convertType(operation.getType());
635 rewriter.template replaceOpWithNewOp<LLVMOp>(
636 operation, dstType, adaptor.getOperands(), operation->getAttrs());
643 class ExecutionModePattern
649 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
654 ModuleOp module = op->getParentOfType<ModuleOp>();
655 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
656 std::string moduleName;
657 if (module.getName().has_value())
658 moduleName =
"_" + module.getName()->str();
661 std::string executionModeInfoName = llvm::formatv(
662 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
663 static_cast<uint32_t
>(executionModeAttr.getValue()));
674 auto llvmI32Type = IntegerType::get(context, 32);
676 fields.push_back(llvmI32Type);
677 ArrayAttr values = op.getValues();
678 if (!values.empty()) {
679 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
680 fields.push_back(arrayType);
685 auto global = rewriter.
create<LLVM::GlobalOp>(
686 UnknownLoc::get(context), structType,
true,
687 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
690 Region ®ion = global.getInitializerRegion();
695 Value structValue = rewriter.
create<LLVM::UndefOp>(loc, structType);
696 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
699 static_cast<uint32_t
>(executionModeAttr.getValue())));
700 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
704 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
705 auto attr = values.getValue()[i];
706 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
707 structValue = rewriter.
create<LLVM::InsertValueOp>(
720 class GlobalVariablePattern
726 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
730 if (op.getInitializer())
734 auto dstType = typeConverter.convertType(srcType.getPointeeType());
741 auto storageClass = srcType.getStorageClass();
742 switch (storageClass) {
743 case spirv::StorageClass::Input:
744 case spirv::StorageClass::Private:
745 case spirv::StorageClass::Output:
746 case spirv::StorageClass::StorageBuffer:
747 case spirv::StorageClass::UniformConstant:
756 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
757 (storageClass == spirv::StorageClass::UniformConstant);
763 auto linkage = storageClass == spirv::StorageClass::Private
764 ? LLVM::Linkage::Private
765 : LLVM::Linkage::External;
767 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
771 if (op.getLocationAttr())
772 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
780 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
786 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
789 Type fromType = operation.getOperand().getType();
790 Type toType = operation.getType();
792 auto dstType = this->typeConverter.convertType(toType);
797 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
798 adaptor.getOperands());
802 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
803 adaptor.getOperands());
810 class FunctionCallPattern
816 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
818 if (callOp.getNumResults() == 0) {
820 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
825 auto dstType = typeConverter.convertType(callOp.getType(0));
827 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
833 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
839 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
842 auto dstType = this->typeConverter.convertType(operation.getType());
846 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
847 operation, dstType, predicate, operation.getOperand1(),
848 operation.getOperand2());
854 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
860 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
863 auto dstType = this->typeConverter.convertType(operation.getType());
867 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
868 operation, dstType, predicate, operation.getOperand1(),
869 operation.getOperand2());
874 class InverseSqrtPattern
880 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
882 auto srcType = op.getType();
883 auto dstType = typeConverter.convertType(srcType);
889 Value sqrt = rewriter.
create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
896 template <
typename SPIRVOp>
902 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
904 if (!op.getMemoryAccess()) {
906 this->typeConverter, 0,
910 auto memoryAccess = *op.getMemoryAccess();
911 switch (memoryAccess) {
912 case spirv::MemoryAccess::Aligned:
914 case spirv::MemoryAccess::Nontemporal:
915 case spirv::MemoryAccess::Volatile: {
917 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
918 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
919 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
921 this->typeConverter, alignment, isVolatile,
932 template <
typename SPIRVOp>
938 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
940 auto srcType = notOp.getType();
941 auto dstType = this->typeConverter.convertType(srcType);
947 auto mask = srcType.template isa<VectorType>()
948 ? rewriter.
create<LLVM::ConstantOp>(
951 srcType.template cast<VectorType>(), minusOne))
952 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
953 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
954 notOp.getOperand(), mask);
960 template <
typename SPIRVOp>
966 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
978 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
991 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
994 adaptor.getOperands());
1052 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1068 Block *entryBlock = loopOp.getEntryBlock();
1070 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1073 Block *headerBlock = loopOp.getHeaderBlock();
1075 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1079 Block *mergeBlock = loopOp.getMergeBlock();
1083 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1099 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1111 if (op.getBody().getBlocks().size() <= 2) {
1123 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1129 auto *headerBlock = op.getHeaderBlock();
1131 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1138 auto *mergeBlock = op.getMergeBlock();
1142 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1145 Block *trueBlock = condBrOp.getTrueBlock();
1146 Block *falseBlock = condBrOp.getFalseBlock();
1148 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1149 condBrOp.getTrueTargetOperands(), falseBlock,
1150 condBrOp.getFalseTargetOperands());
1153 rewriter.
replaceOp(op, continueBlock->getArguments());
1162 template <
typename SPIRVOp,
typename LLVMOp>
1168 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
1171 auto dstType = this->typeConverter.convertType(operation.getType());
1175 Type op1Type = operation.getOperand1().getType();
1176 Type op2Type = operation.getOperand2().getType();
1178 if (op1Type == op2Type) {
1179 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1180 adaptor.getOperands());
1187 extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1188 adaptor.getOperand2());
1190 extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1191 adaptor.getOperand2());
1193 Value result = rewriter.template create<LLVMOp>(
1194 loc, dstType, adaptor.getOperand1(), extended);
1205 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1207 auto dstType = typeConverter.convertType(tanOp.getType());
1212 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1213 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1230 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1232 auto srcType = tanhOp.getType();
1233 auto dstType = typeConverter.convertType(srcType);
1240 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1241 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1244 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1246 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1258 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1260 auto srcType = varOp.getType();
1263 auto init = varOp.getInitializer();
1264 if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1267 auto dstType = typeConverter.convertType(srcType);
1275 varOp, dstType, typeConverter.convertType(pointerTo), size);
1278 Value allocated = rewriter.
create<LLVM::AllocaOp>(
1279 loc, dstType, typeConverter.convertType(pointerTo), size);
1280 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1290 class BitcastConversionPattern
1296 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1298 auto dstType = typeConverter.convertType(bitcastOp.getType());
1302 if (typeConverter.useOpaquePointers() &&
1303 dstType.isa<LLVM::LLVMPointerType>()) {
1304 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1309 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1323 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1328 auto funcType = funcOp.getFunctionType();
1330 funcType.getNumInputs());
1331 auto llvmType = typeConverter.convertFunctionSignature(
1332 funcType,
false, signatureConverter);
1338 StringRef name = funcOp.getName();
1339 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1343 switch (funcOp.getFunctionControl()) {
1344 #define DISPATCH(functionControl, llvmAttr) \
1345 case functionControl: \
1346 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1349 DISPATCH(spirv::FunctionControl::Inline,
1350 StringAttr::get(context,
"alwaysinline"));
1351 DISPATCH(spirv::FunctionControl::DontInline,
1352 StringAttr::get(context,
"noinline"));
1353 DISPATCH(spirv::FunctionControl::Pure,
1354 StringAttr::get(context,
"readonly"));
1355 DISPATCH(spirv::FunctionControl::Const,
1356 StringAttr::get(context,
"readnone"));
1369 &signatureConverter))) {
1386 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1390 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1394 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1395 rewriter.
eraseOp(spvModuleOp);
1404 class VectorShufflePattern
1409 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1412 auto components = adaptor.getComponents();
1413 auto vector1 = adaptor.getVector1();
1414 auto vector2 = adaptor.getVector2();
1416 int vector2Size = vector2.getType().cast<VectorType>().
getNumElements();
1417 if (vector1Size == vector2Size) {
1419 op, vector1, vector2,
1420 LLVM::convertArrayToIndices<int32_t>(components));
1424 auto dstType = typeConverter.convertType(op.getType());
1426 auto componentsArray = components.getValue();
1428 auto llvmI32Type = IntegerType::get(context, 32);
1429 Value targetOp = rewriter.
create<LLVM::UndefOp>(loc, dstType);
1430 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1431 if (!componentsArray[i].isa<IntegerAttr>())
1432 return op.
emitError(
"unable to support non-constant component");
1434 int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1439 Value baseVector = vector1;
1440 if (indexVal >= vector1Size) {
1441 offsetVal = vector1Size;
1442 baseVector = vector2;
1445 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1451 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1452 loc, scalarType, baseVector, index);
1453 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1454 extractOp, dstIndex);
1485 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1486 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1487 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1488 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1489 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1490 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1491 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1492 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1493 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1494 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1495 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1496 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1497 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1500 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1501 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1502 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1503 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1504 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1505 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1506 NotPattern<spirv::NotOp>,
1509 BitcastConversionPattern,
1510 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1511 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1512 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1513 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1514 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1515 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1516 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1519 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1520 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1521 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1522 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1523 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1524 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1525 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1526 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1527 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1528 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1529 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1530 LLVM::FCmpPredicate::uge>,
1531 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1532 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1533 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1534 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1535 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1536 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1537 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1538 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1539 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1540 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1541 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1544 ConstantScalarAndVectorPattern,
1547 BranchConversionPattern, BranchConditionalConversionPattern,
1548 FunctionCallPattern, LoopPattern, SelectionPattern,
1549 ErasePattern<spirv::MergeOp>,
1552 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1555 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1556 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1557 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1558 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1559 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1560 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1561 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1562 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1563 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1564 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1565 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1566 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1567 InverseSqrtPattern, TanPattern, TanhPattern,
1570 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1571 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1572 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1573 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1574 NotPattern<spirv::LogicalNotOp>,
1577 AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1578 LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1582 CompositeExtractPattern, CompositeInsertPattern,
1583 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1584 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1585 VectorShufflePattern,
1588 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1589 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1590 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1593 ReturnPattern, ReturnValuePattern>(patterns.
getContext(), typeConverter);
1598 patterns.
add<FuncConversionPattern>(patterns.
getContext(), typeConverter);
1603 patterns.
add<ModuleConversionPattern>(patterns.
getContext(), typeConverter);
1614 auto spvModules = module.getOps<spirv::ModuleOp>();
1615 for (
auto spvModule : spvModules) {
1616 spvModule.walk([&](spirv::GlobalVariableOp op) {
1617 IntegerAttr descriptorSet =
1619 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
1622 if (descriptorSet && binding) {
1625 auto moduleAndName =
1626 spvModule.getName().has_value()
1627 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
1628 : op.getSymName().str();
1630 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
1631 std::to_string(descriptorSet.getInt()),
1632 std::to_string(binding.getInt()));
1633 auto nameAttr = StringAttr::get(op->getContext(), name);
1638 op.emitError(
"unable to replace all symbol uses for ") << name;
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops: zero-extends or truncates the value when its bit width differs from that of llvmType.
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter)
Converts SPIR-V pointer type to LLVM pointer.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static std::optional< Type > convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static std::optional< Type > convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static int64_t getNumElements(ShapedType type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
Conversion from types to the LLVM IR dialect.
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0)
Creates an LLVM pointer type with the given element type and address space.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
Type getPointeeType() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
ElementTypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns that convert from SPIR-V to LLVM.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.