25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
46 if (
auto vecType = dyn_cast<VectorType>(type))
47 return vecType.getElementType().isSignedInteger();
55 if (
auto vecType = dyn_cast<VectorType>(type))
56 return vecType.getElementType().isUnsignedInteger();
63 if (
auto intType = dyn_cast<IntegerType>(type))
64 return intType.getWidth();
65 if (
auto vecType = dyn_cast<VectorType>(type))
66 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
67 return intType.getWidth();
74 "bitwidth is not supported for this type");
77 auto vecType = dyn_cast<VectorType>(type);
78 auto elementType = vecType.getElementType();
79 assert(elementType.isIntOrFloat() &&
80 "only integers and floats have a bitwidth");
81 return elementType.getIntOrFloatBitWidth();
94 if (
auto vecType = dyn_cast<VectorType>(type)) {
95 auto integerType = cast<IntegerType>(vecType.getElementType());
98 auto integerType = cast<IntegerType>(type);
105 if (isa<VectorType>(srcType)) {
106 return rewriter.
create<LLVM::ConstantOp>(
111 return rewriter.
create<LLVM::ConstantOp>(
118 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
119 auto floatType = cast<FloatType>(vecType.getElementType());
120 return rewriter.
create<LLVM::ConstantOp>(
125 auto floatType = cast<FloatType>(srcType);
126 return rewriter.
create<LLVM::ConstantOp>(
139 auto srcType = value.
getType();
145 if (valueBitWidth < targetBitWidth)
146 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
151 if (valueBitWidth > targetBitWidth)
152 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
161 auto llvmVectorType = typeConverter.
convertType(vectorType);
163 Value broadcasted = rewriter.
create<LLVM::UndefOp>(loc, llvmVectorType);
164 for (
unsigned i = 0; i < numElements; ++i) {
165 auto index = rewriter.
create<LLVM::ConstantOp>(
167 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
168 loc, llvmVectorType, broadcasted, toBroadcast, index);
177 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
178 unsigned numElements = vectorType.getNumElements();
179 return broadcast(loc, value, numElements, typeConverter, rewriter);
229 return rewriter.
create<LLVM::ConstantOp>(
238 unsigned alignment,
bool isVolatile,
239 bool isNonTemporal) {
240 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
241 auto dstType = typeConverter.
convertType(loadOp.getType());
245 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
246 isVolatile, isNonTemporal);
249 auto storeOp = cast<spirv::StoreOp>(op);
250 spirv::StoreOpAdaptor adaptor(operands);
252 adaptor.getPtr(), alignment,
253 isVolatile, isNonTemporal);
268 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
269 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
272 auto llvmElementType = converter.
convertType(elementType);
281 switch (storageClass) {
282 #define STORAGE_SPACE_MAP(storage, space) \
283 case spirv::StorageClass::storage: \
293 #undef STORAGE_SPACE_MAP
300 spirv::StorageClass storageClass) {
302 #define CLIENT_MAP(client, storage) \
303 case spirv::ClientAPI::client: \
304 return mapTo##client##AddressSpace(storage);
316 spirv::ClientAPI clientAPI) {
338 if (!memberDecorations.empty())
356 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
358 auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
362 auto indices = llvm::to_vector<4>(adaptor.getIndices());
363 Type indexType = op.getIndices().front().getType();
364 auto llvmIndexType = typeConverter.convertType(indexType);
369 indices.insert(indices.begin(), zero);
372 typeConverter.convertType(
373 cast<spirv::PointerType>(op.getBasePtr().getType())
375 adaptor.getBasePtr(), indices);
385 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
387 auto dstType = typeConverter.convertType(op.getPointer().getType());
396 class BitFieldInsertPattern
402 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
404 auto srcType = op.getType();
405 auto dstType = typeConverter.convertType(srcType);
412 typeConverter, rewriter);
414 typeConverter, rewriter);
418 Value maskShiftedByCount =
419 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
420 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
421 maskShiftedByCount, minusOne);
422 Value maskShiftedByCountAndOffset =
423 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
425 loc, dstType, maskShiftedByCountAndOffset, minusOne);
430 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
431 Value insertShiftedByOffset =
432 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
434 insertShiftedByOffset);
440 class ConstantScalarAndVectorPattern
446 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
448 auto srcType = constOp.getType();
449 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
452 auto dstType = typeConverter.convertType(srcType);
465 if (isa<VectorType>(srcType)) {
466 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
469 dstElementsAttr.mapValues(
470 signlessType, [&](
const APInt &value) {
return value; }));
473 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
474 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
479 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
484 class BitFieldSExtractPattern
490 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
492 auto srcType = op.getType();
493 auto dstType = typeConverter.convertType(srcType);
500 typeConverter, rewriter);
502 typeConverter, rewriter);
505 IntegerType integerType;
506 if (
auto vecType = dyn_cast<VectorType>(srcType))
507 integerType = cast<IntegerType>(vecType.getElementType());
509 integerType = cast<IntegerType>(srcType);
513 isa<VectorType>(srcType)
514 ? rewriter.
create<LLVM::ConstantOp>(
517 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
521 Value countPlusOffset =
522 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
523 Value amountToShiftLeft =
524 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
525 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
526 loc, dstType, op.getBase(), amountToShiftLeft);
529 Value amountToShiftRight =
530 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
537 class BitFieldUExtractPattern
543 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
545 auto srcType = op.getType();
546 auto dstType = typeConverter.convertType(srcType);
553 typeConverter, rewriter);
555 typeConverter, rewriter);
559 Value maskShiftedByCount =
560 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
561 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
566 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
577 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
580 branchOp.getTarget());
585 class BranchConditionalConversionPattern
592 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
596 if (
auto weights = op.getBranchWeights()) {
598 for (
auto weight : weights->getAsRange<IntegerAttr>())
599 weightValues.push_back(weight.getInt());
604 op, op.getCondition(), op.getTrueBlockArguments(),
605 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
614 class CompositeExtractPattern
620 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
622 auto dstType = this->typeConverter.convertType(op.getType());
626 Type containerType = op.getComposite().getType();
627 if (isa<VectorType>(containerType)) {
629 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
632 op, dstType, adaptor.getComposite(), index);
637 op, adaptor.getComposite(),
646 class CompositeInsertPattern
652 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
654 auto dstType = this->typeConverter.convertType(op.getType());
658 Type containerType = op.getComposite().getType();
659 if (isa<VectorType>(containerType)) {
661 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
664 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
669 op, adaptor.getComposite(), adaptor.getObject(),
677 template <
typename SPIRVOp,
typename LLVMOp>
683 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
685 auto dstType = this->typeConverter.convertType(operation.getType());
688 rewriter.template replaceOpWithNewOp<LLVMOp>(
689 operation, dstType, adaptor.getOperands(), operation->getAttrs());
696 class ExecutionModePattern
702 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
708 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
709 std::string moduleName;
710 if (module.getName().has_value())
711 moduleName =
"_" + module.
getName()->str();
714 std::string executionModeInfoName = llvm::formatv(
715 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
716 static_cast<uint32_t
>(executionModeAttr.getValue()));
729 fields.push_back(llvmI32Type);
730 ArrayAttr values = op.getValues();
731 if (!values.empty()) {
733 fields.push_back(arrayType);
738 auto global = rewriter.
create<LLVM::GlobalOp>(
740 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
743 Region ®ion = global.getInitializerRegion();
748 Value structValue = rewriter.
create<LLVM::UndefOp>(loc, structType);
749 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
752 static_cast<uint32_t
>(executionModeAttr.getValue())));
753 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
757 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
758 auto attr = values.getValue()[i];
759 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
760 structValue = rewriter.
create<LLVM::InsertValueOp>(
773 class GlobalVariablePattern
776 template <
typename... Args>
777 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
779 std::forward<Args>(args)...),
780 clientAPI(clientAPI) {}
783 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
787 if (op.getInitializer())
790 auto srcType = cast<spirv::PointerType>(op.getType());
791 auto dstType = typeConverter.convertType(srcType.getPointeeType());
798 auto storageClass = srcType.getStorageClass();
799 switch (storageClass) {
800 case spirv::StorageClass::Input:
801 case spirv::StorageClass::Private:
802 case spirv::StorageClass::Output:
803 case spirv::StorageClass::StorageBuffer:
804 case spirv::StorageClass::UniformConstant:
813 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
814 (storageClass == spirv::StorageClass::UniformConstant);
820 auto linkage = storageClass == spirv::StorageClass::Private
821 ? LLVM::Linkage::Private
822 : LLVM::Linkage::External;
824 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
828 if (op.getLocationAttr())
829 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
835 spirv::ClientAPI clientAPI;
840 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
846 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
849 Type fromType = operation.getOperand().getType();
850 Type toType = operation.getType();
852 auto dstType = this->typeConverter.convertType(toType);
857 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
858 adaptor.getOperands());
862 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
863 adaptor.getOperands());
870 class FunctionCallPattern
876 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
878 if (callOp.getNumResults() == 0) {
880 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
885 auto dstType = typeConverter.convertType(callOp.getType(0));
887 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
893 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
899 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
902 auto dstType = this->typeConverter.convertType(operation.getType());
906 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
907 operation, dstType, predicate, operation.getOperand1(),
908 operation.getOperand2());
914 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
920 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
923 auto dstType = this->typeConverter.convertType(operation.getType());
927 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
928 operation, dstType, predicate, operation.getOperand1(),
929 operation.getOperand2());
934 class InverseSqrtPattern
940 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
942 auto srcType = op.getType();
943 auto dstType = typeConverter.convertType(srcType);
956 template <
typename SPIRVOp>
962 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
964 if (!op.getMemoryAccess()) {
966 this->typeConverter, 0,
970 auto memoryAccess = *op.getMemoryAccess();
971 switch (memoryAccess) {
972 case spirv::MemoryAccess::Aligned:
974 case spirv::MemoryAccess::Nontemporal:
975 case spirv::MemoryAccess::Volatile: {
977 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
978 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
979 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
981 this->typeConverter, alignment, isVolatile,
992 template <
typename SPIRVOp>
998 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
1000 auto srcType = notOp.getType();
1001 auto dstType = this->typeConverter.convertType(srcType);
1008 isa<VectorType>(srcType)
1009 ? rewriter.
create<LLVM::ConstantOp>(
1012 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
1013 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
1014 notOp.getOperand(), mask);
1020 template <
typename SPIRVOp>
1026 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1038 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1051 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1054 adaptor.getOperands());
1112 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1128 Block *entryBlock = loopOp.getEntryBlock();
1130 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1133 Block *headerBlock = loopOp.getHeaderBlock();
1135 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1139 Block *mergeBlock = loopOp.getMergeBlock();
1143 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1159 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1171 if (op.getBody().getBlocks().size() <= 2) {
1183 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1189 auto *headerBlock = op.getHeaderBlock();
1191 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1198 auto *mergeBlock = op.getMergeBlock();
1202 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1205 Block *trueBlock = condBrOp.getTrueBlock();
1206 Block *falseBlock = condBrOp.getFalseBlock();
1208 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1209 condBrOp.getTrueTargetOperands(),
1211 condBrOp.getFalseTargetOperands());
1214 rewriter.
replaceOp(op, continueBlock->getArguments());
1223 template <
typename SPIRVOp,
typename LLVMOp>
1229 matchAndRewrite(SPIRVOp operation,
typename SPIRVOp::Adaptor adaptor,
1232 auto dstType = this->typeConverter.convertType(operation.getType());
1236 Type op1Type = operation.getOperand1().getType();
1237 Type op2Type = operation.getOperand2().getType();
1239 if (op1Type == op2Type) {
1240 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1241 adaptor.getOperands());
1245 std::optional<uint64_t> dstTypeWidth =
1247 std::optional<uint64_t> op2TypeWidth =
1250 if (!dstTypeWidth || !op2TypeWidth)
1255 if (op2TypeWidth < dstTypeWidth) {
1257 extended = rewriter.template create<LLVM::ZExtOp>(
1258 loc, dstType, adaptor.getOperand2());
1260 extended = rewriter.template create<LLVM::SExtOp>(
1261 loc, dstType, adaptor.getOperand2());
1263 }
else if (op2TypeWidth == dstTypeWidth) {
1264 extended = adaptor.getOperand2();
1269 Value result = rewriter.template create<LLVMOp>(
1270 loc, dstType, adaptor.getOperand1(), extended);
1281 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1283 auto dstType = typeConverter.convertType(tanOp.getType());
1288 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1289 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1306 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1308 auto srcType = tanhOp.getType();
1309 auto dstType = typeConverter.convertType(srcType);
1316 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1317 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1320 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1322 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1334 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1336 auto srcType = varOp.getType();
1338 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1339 auto init = varOp.getInitializer();
1340 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1343 auto dstType = typeConverter.convertType(srcType);
1351 varOp, dstType, typeConverter.convertType(pointerTo), size);
1354 Value allocated = rewriter.
create<LLVM::AllocaOp>(
1355 loc, dstType, typeConverter.convertType(pointerTo), size);
1356 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1366 class BitcastConversionPattern
1372 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1374 auto dstType = typeConverter.convertType(bitcastOp.getType());
1379 if (isa<LLVM::LLVMPointerType>(dstType)) {
1380 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1385 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1399 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1404 auto funcType = funcOp.getFunctionType();
1406 funcType.getNumInputs());
1407 auto llvmType = typeConverter.convertFunctionSignature(
1408 funcType,
false,
false,
1409 signatureConverter);
1415 StringRef name = funcOp.getName();
1416 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1420 switch (funcOp.getFunctionControl()) {
1421 #define DISPATCH(functionControl, llvmAttr) \
1422 case functionControl: \
1423 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1426 DISPATCH(spirv::FunctionControl::Inline,
1428 DISPATCH(spirv::FunctionControl::DontInline,
1430 DISPATCH(spirv::FunctionControl::Pure,
1432 DISPATCH(spirv::FunctionControl::Const,
1446 &signatureConverter))) {
1463 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1467 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1471 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1472 rewriter.
eraseOp(spvModuleOp);
1481 class VectorShufflePattern
1486 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1489 auto components = adaptor.getComponents();
1490 auto vector1 = adaptor.getVector1();
1491 auto vector2 = adaptor.getVector2();
1492 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1493 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1494 if (vector1Size == vector2Size) {
1496 op, vector1, vector2,
1497 LLVM::convertArrayToIndices<int32_t>(components));
1501 auto dstType = typeConverter.convertType(op.getType());
1502 auto scalarType = cast<VectorType>(dstType).getElementType();
1503 auto componentsArray = components.getValue();
1506 Value targetOp = rewriter.
create<LLVM::UndefOp>(loc, dstType);
1507 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1508 if (!isa<IntegerAttr>(componentsArray[i]))
1509 return op.
emitError(
"unable to support non-constant component");
1511 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1516 Value baseVector = vector1;
1517 if (indexVal >= vector1Size) {
1518 offsetVal = vector1Size;
1519 baseVector = vector2;
1522 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1528 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1529 loc, scalarType, baseVector, index);
1530 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1531 extractOp, dstIndex);
1544 spirv::ClientAPI clientAPI) {
1561 spirv::ClientAPI clientAPI) {
1564 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1565 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1566 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1567 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1568 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1569 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1570 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1571 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1572 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1573 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1574 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1575 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1576 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1579 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1580 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1581 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1582 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1583 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1584 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1585 NotPattern<spirv::NotOp>,
1588 BitcastConversionPattern,
1589 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1590 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1591 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1592 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1593 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1594 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1595 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1598 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1599 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1600 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1601 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1602 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1603 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1604 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1605 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1606 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1607 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1608 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1609 LLVM::FCmpPredicate::uge>,
1610 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1611 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1612 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1613 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1614 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1615 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1616 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1617 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1618 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1619 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1620 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1623 ConstantScalarAndVectorPattern,
1626 BranchConversionPattern, BranchConditionalConversionPattern,
1627 FunctionCallPattern, LoopPattern, SelectionPattern,
1628 ErasePattern<spirv::MergeOp>,
1631 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1634 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1635 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1636 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1637 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1638 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1639 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1640 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1641 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1642 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1643 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1644 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1645 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1646 InverseSqrtPattern, TanPattern, TanhPattern,
1649 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1650 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1651 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1652 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1653 NotPattern<spirv::LogicalNotOp>,
1656 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1657 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1660 CompositeExtractPattern, CompositeInsertPattern,
1661 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1662 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1663 VectorShufflePattern,
1666 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1667 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1668 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1671 ReturnPattern, ReturnValuePattern>(patterns.
getContext(), typeConverter);
1673 patterns.
add<GlobalVariablePattern>(clientAPI, patterns.
getContext(),
1679 patterns.
add<FuncConversionPattern>(patterns.
getContext(), typeConverter);
1684 patterns.
add<ModuleConversionPattern>(patterns.
getContext(), typeConverter);
1695 auto spvModules = module.getOps<spirv::ModuleOp>();
1696 for (
auto spvModule : spvModules) {
1697 spvModule.walk([&](spirv::GlobalVariableOp op) {
1698 IntegerAttr descriptorSet =
1703 if (descriptorSet && binding) {
1706 auto moduleAndName =
1707 spvModule.getName().has_value()
1708 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
1709 : op.getSymName().str();
1711 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
1712 std::to_string(descriptorSet.getInt()),
1713 std::to_string(binding.getInt()));
1719 op.
emitError(
"unable to replace all symbol uses for ") << name;
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass)
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
constexpr unsigned defaultAddressSpace
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
#define STORAGE_SPACE_MAP(storage, space)
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
static Type convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
#define CLIENT_MAP(client, storage)
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the results of the given op with a new op that is created without verification.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.