24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
27 #define DEBUG_TYPE "spirv-to-llvm-pattern"
45 if (
auto vecType = dyn_cast<VectorType>(type))
46 return vecType.getElementType().isSignedInteger();
54 if (
auto vecType = dyn_cast<VectorType>(type))
55 return vecType.getElementType().isUnsignedInteger();
62 if (
auto intType = dyn_cast<IntegerType>(type))
63 return intType.getWidth();
64 if (
auto vecType = dyn_cast<VectorType>(type))
65 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
66 return intType.getWidth();
73 "bitwidth is not supported for this type");
76 auto vecType = dyn_cast<VectorType>(type);
77 auto elementType = vecType.getElementType();
78 assert(elementType.isIntOrFloat() &&
79 "only integers and floats have a bitwidth");
80 return elementType.getIntOrFloatBitWidth();
93 if (
auto vecType = dyn_cast<VectorType>(type)) {
94 auto integerType = cast<IntegerType>(vecType.getElementType());
97 auto integerType = cast<IntegerType>(type);
104 if (isa<VectorType>(srcType)) {
105 return rewriter.
create<LLVM::ConstantOp>(
110 return rewriter.
create<LLVM::ConstantOp>(
117 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
118 auto floatType = cast<FloatType>(vecType.getElementType());
119 return rewriter.
create<LLVM::ConstantOp>(
124 auto floatType = cast<FloatType>(srcType);
125 return rewriter.
create<LLVM::ConstantOp>(
138 auto srcType = value.
getType();
144 if (valueBitWidth < targetBitWidth)
145 return rewriter.
create<LLVM::ZExtOp>(loc, llvmType, value);
150 if (valueBitWidth > targetBitWidth)
151 return rewriter.
create<LLVM::TruncOp>(loc, llvmType, value);
160 auto llvmVectorType = typeConverter.
convertType(vectorType);
162 Value broadcasted = rewriter.
create<LLVM::UndefOp>(loc, llvmVectorType);
163 for (
unsigned i = 0; i < numElements; ++i) {
164 auto index = rewriter.
create<LLVM::ConstantOp>(
166 broadcasted = rewriter.
create<LLVM::InsertElementOp>(
167 loc, llvmVectorType, broadcasted, toBroadcast, index);
176 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
177 unsigned numElements = vectorType.getNumElements();
178 return broadcast(loc, value, numElements, typeConverter, rewriter);
228 return rewriter.
create<LLVM::ConstantOp>(
237 unsigned alignment,
bool isVolatile,
238 bool isNonTemporal) {
239 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
240 auto dstType = typeConverter.
convertType(loadOp.getType());
244 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
245 isVolatile, isNonTemporal);
248 auto storeOp = cast<spirv::StoreOp>(op);
249 spirv::StoreOpAdaptor adaptor(operands);
251 adaptor.getPtr(), alignment,
252 isVolatile, isNonTemporal);
267 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
268 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
271 auto llvmElementType = converter.
convertType(elementType);
280 switch (storageClass) {
281 #define STORAGE_SPACE_MAP(storage, space) \
282 case spirv::StorageClass::storage: \
292 #undef STORAGE_SPACE_MAP
299 spirv::StorageClass storageClass) {
301 #define CLIENT_MAP(client, storage) \
302 case spirv::ClientAPI::client: \
303 return mapTo##client##AddressSpace(storage);
315 spirv::ClientAPI clientAPI) {
337 if (!memberDecorations.empty())
355 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
357 auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
361 auto indices = llvm::to_vector<4>(adaptor.getIndices());
362 Type indexType = op.getIndices().front().getType();
363 auto llvmIndexType = typeConverter.convertType(indexType);
368 indices.insert(indices.begin(), zero);
370 auto elementType = typeConverter.convertType(
371 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
375 adaptor.getBasePtr(), indices);
385 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
387 auto dstType = typeConverter.convertType(op.getPointer().getType());
396 class BitFieldInsertPattern
402 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
404 auto srcType = op.getType();
405 auto dstType = typeConverter.convertType(srcType);
412 typeConverter, rewriter);
414 typeConverter, rewriter);
418 Value maskShiftedByCount =
419 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
420 Value negated = rewriter.
create<LLVM::XOrOp>(loc, dstType,
421 maskShiftedByCount, minusOne);
422 Value maskShiftedByCountAndOffset =
423 rewriter.
create<LLVM::ShlOp>(loc, dstType, negated, offset);
425 loc, dstType, maskShiftedByCountAndOffset, minusOne);
430 rewriter.
create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
431 Value insertShiftedByOffset =
432 rewriter.
create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
434 insertShiftedByOffset);
440 class ConstantScalarAndVectorPattern
446 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
448 auto srcType = constOp.getType();
449 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
452 auto dstType = typeConverter.convertType(srcType);
465 if (isa<VectorType>(srcType)) {
466 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
469 dstElementsAttr.mapValues(
470 signlessType, [&](
const APInt &value) {
return value; }));
473 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
474 auto dstAttr = rewriter.
getIntegerAttr(signlessType, srcAttr.getValue());
479 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
484 class BitFieldSExtractPattern
490 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
492 auto srcType = op.getType();
493 auto dstType = typeConverter.convertType(srcType);
500 typeConverter, rewriter);
502 typeConverter, rewriter);
505 IntegerType integerType;
506 if (
auto vecType = dyn_cast<VectorType>(srcType))
507 integerType = cast<IntegerType>(vecType.getElementType());
509 integerType = cast<IntegerType>(srcType);
513 isa<VectorType>(srcType)
514 ? rewriter.
create<LLVM::ConstantOp>(
517 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, baseSize);
521 Value countPlusOffset =
522 rewriter.
create<LLVM::AddOp>(loc, dstType, count, offset);
523 Value amountToShiftLeft =
524 rewriter.
create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
525 Value baseShiftedLeft = rewriter.
create<LLVM::ShlOp>(
526 loc, dstType, op.getBase(), amountToShiftLeft);
529 Value amountToShiftRight =
530 rewriter.
create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
537 class BitFieldUExtractPattern
543 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
545 auto srcType = op.getType();
546 auto dstType = typeConverter.convertType(srcType);
553 typeConverter, rewriter);
555 typeConverter, rewriter);
559 Value maskShiftedByCount =
560 rewriter.
create<LLVM::ShlOp>(loc, dstType, minusOne, count);
561 Value mask = rewriter.
create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
566 rewriter.
create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
577 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
580 branchOp.getTarget());
585 class BranchConditionalConversionPattern
592 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
596 if (
auto weights = op.getBranchWeights()) {
598 for (
auto weight : weights->getAsRange<IntegerAttr>())
599 weightValues.push_back(weight.getInt());
604 op, op.getCondition(), op.getTrueBlockArguments(),
605 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
614 class CompositeExtractPattern
620 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
622 auto dstType = this->typeConverter.convertType(op.getType());
626 Type containerType = op.getComposite().getType();
627 if (isa<VectorType>(containerType)) {
629 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
632 op, dstType, adaptor.getComposite(), index);
637 op, adaptor.getComposite(),
646 class CompositeInsertPattern
652 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
654 auto dstType = this->typeConverter.convertType(op.getType());
658 Type containerType = op.getComposite().getType();
659 if (isa<VectorType>(containerType)) {
661 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
664 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
669 op, adaptor.getComposite(), adaptor.getObject(),
677 template <
typename SPIRVOp,
typename LLVMOp>
683 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
685 auto dstType = this->typeConverter.convertType(op.getType());
688 rewriter.template replaceOpWithNewOp<LLVMOp>(
689 op, dstType, adaptor.getOperands(), op->
getAttrs());
696 class ExecutionModePattern
702 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
708 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
709 std::string moduleName;
710 if (module.getName().has_value())
711 moduleName =
"_" + module.
getName()->str();
714 std::string executionModeInfoName = llvm::formatv(
715 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
716 static_cast<uint32_t
>(executionModeAttr.getValue()));
729 fields.push_back(llvmI32Type);
730 ArrayAttr values = op.getValues();
731 if (!values.empty()) {
733 fields.push_back(arrayType);
738 auto global = rewriter.
create<LLVM::GlobalOp>(
740 LLVM::Linkage::External, executionModeInfoName,
Attribute(),
743 Region ®ion = global.getInitializerRegion();
748 Value structValue = rewriter.
create<LLVM::UndefOp>(loc, structType);
749 Value executionMode = rewriter.
create<LLVM::ConstantOp>(
752 static_cast<uint32_t
>(executionModeAttr.getValue())));
753 structValue = rewriter.
create<LLVM::InsertValueOp>(loc, structValue,
757 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
758 auto attr = values.getValue()[i];
759 Value entry = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
760 structValue = rewriter.
create<LLVM::InsertValueOp>(
773 class GlobalVariablePattern
776 template <
typename... Args>
777 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
779 std::forward<Args>(args)...),
780 clientAPI(clientAPI) {}
783 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
787 if (op.getInitializer())
790 auto srcType = cast<spirv::PointerType>(op.getType());
791 auto dstType = typeConverter.convertType(srcType.getPointeeType());
798 auto storageClass = srcType.getStorageClass();
799 switch (storageClass) {
800 case spirv::StorageClass::Input:
801 case spirv::StorageClass::Private:
802 case spirv::StorageClass::Output:
803 case spirv::StorageClass::StorageBuffer:
804 case spirv::StorageClass::UniformConstant:
813 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
814 (storageClass == spirv::StorageClass::UniformConstant);
820 auto linkage = storageClass == spirv::StorageClass::Private
821 ? LLVM::Linkage::Private
822 : LLVM::Linkage::External;
824 op, dstType, isConstant, linkage, op.getSymName(),
Attribute(),
828 if (op.getLocationAttr())
829 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
835 spirv::ClientAPI clientAPI;
840 template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
846 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
850 Type toType = op.getType();
852 auto dstType = this->typeConverter.convertType(toType);
857 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
858 adaptor.getOperands());
862 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
863 adaptor.getOperands());
870 class FunctionCallPattern
876 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
878 if (callOp.getNumResults() == 0) {
880 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
885 auto dstType = typeConverter.convertType(callOp.getType(0));
889 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
895 template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
901 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
904 auto dstType = this->typeConverter.convertType(op.getType());
908 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
909 op, dstType, predicate, op.getOperand1(), op.getOperand2());
915 template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
921 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
924 auto dstType = this->typeConverter.convertType(op.getType());
928 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
929 op, dstType, predicate, op.getOperand1(), op.getOperand2());
934 class InverseSqrtPattern
940 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
942 auto srcType = op.getType();
943 auto dstType = typeConverter.convertType(srcType);
956 template <
typename SPIRVOp>
962 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
964 if (!op.getMemoryAccess()) {
966 this->typeConverter, 0,
970 auto memoryAccess = *op.getMemoryAccess();
971 switch (memoryAccess) {
972 case spirv::MemoryAccess::Aligned:
974 case spirv::MemoryAccess::Nontemporal:
975 case spirv::MemoryAccess::Volatile: {
977 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
978 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
979 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
981 this->typeConverter, alignment, isVolatile,
992 template <
typename SPIRVOp>
998 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
1000 auto srcType = notOp.getType();
1001 auto dstType = this->typeConverter.convertType(srcType);
1008 isa<VectorType>(srcType)
1009 ? rewriter.
create<LLVM::ConstantOp>(
1012 : rewriter.
create<LLVM::ConstantOp>(loc, dstType, minusOne);
1013 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
1014 notOp.getOperand(), mask);
1020 template <
typename SPIRVOp>
1026 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1038 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1051 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1054 adaptor.getOperands());
1112 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1128 Block *entryBlock = loopOp.getEntryBlock();
1130 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1133 Block *headerBlock = loopOp.getHeaderBlock();
1135 rewriter.
create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1139 Block *mergeBlock = loopOp.getMergeBlock();
1143 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1159 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1171 if (op.getBody().getBlocks().size() <= 2) {
1183 auto *continueBlock = rewriter.
splitBlock(currentBlock, position);
1189 auto *headerBlock = op.getHeaderBlock();
1191 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1198 auto *mergeBlock = op.getMergeBlock();
1202 rewriter.
create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1205 Block *trueBlock = condBrOp.getTrueBlock();
1206 Block *falseBlock = condBrOp.getFalseBlock();
1208 rewriter.
create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1209 condBrOp.getTrueTargetOperands(),
1211 condBrOp.getFalseTargetOperands());
1214 rewriter.
replaceOp(op, continueBlock->getArguments());
1223 template <
typename SPIRVOp,
typename LLVMOp>
1229 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1232 auto dstType = this->typeConverter.convertType(op.getType());
1236 Type op1Type = op.getOperand1().getType();
1237 Type op2Type = op.getOperand2().getType();
1239 if (op1Type == op2Type) {
1240 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1241 adaptor.getOperands());
1245 std::optional<uint64_t> dstTypeWidth =
1247 std::optional<uint64_t> op2TypeWidth =
1250 if (!dstTypeWidth || !op2TypeWidth)
1255 if (op2TypeWidth < dstTypeWidth) {
1257 extended = rewriter.template create<LLVM::ZExtOp>(
1258 loc, dstType, adaptor.getOperand2());
1260 extended = rewriter.template create<LLVM::SExtOp>(
1261 loc, dstType, adaptor.getOperand2());
1263 }
else if (op2TypeWidth == dstTypeWidth) {
1264 extended = adaptor.getOperand2();
1269 Value result = rewriter.template create<LLVMOp>(
1270 loc, dstType, adaptor.getOperand1(), extended);
1281 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1283 auto dstType = typeConverter.convertType(tanOp.getType());
1288 Value sin = rewriter.
create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1289 Value cos = rewriter.
create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1306 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1308 auto srcType = tanhOp.getType();
1309 auto dstType = typeConverter.convertType(srcType);
1316 rewriter.
create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1317 Value exponential = rewriter.
create<LLVM::ExpOp>(loc, dstType, multiplied);
1320 rewriter.
create<LLVM::FSubOp>(loc, dstType, exponential, one);
1322 rewriter.
create<LLVM::FAddOp>(loc, dstType, exponential, one);
1334 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1336 auto srcType = varOp.getType();
1338 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1339 auto init = varOp.getInitializer();
1340 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1343 auto dstType = typeConverter.convertType(srcType);
1350 auto elementType = typeConverter.convertType(pointerTo);
1357 auto elementType = typeConverter.convertType(pointerTo);
1361 rewriter.
create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1362 rewriter.
create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1372 class BitcastConversionPattern
1378 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1380 auto dstType = typeConverter.convertType(bitcastOp.getType());
1385 if (isa<LLVM::LLVMPointerType>(dstType)) {
1386 rewriter.
replaceOp(bitcastOp, adaptor.getOperand());
1391 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1405 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1410 auto funcType = funcOp.getFunctionType();
1412 funcType.getNumInputs());
1413 auto llvmType = typeConverter.convertFunctionSignature(
1414 funcType,
false,
false,
1415 signatureConverter);
1421 StringRef name = funcOp.getName();
1422 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1426 switch (funcOp.getFunctionControl()) {
1427 case spirv::FunctionControl::Inline:
1428 newFuncOp.setAlwaysInline(
true);
1430 case spirv::FunctionControl::DontInline:
1431 newFuncOp.setNoInline(
true);
1434 #define DISPATCH(functionControl, llvmAttr) \
1435 case functionControl: \
1436 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1439 DISPATCH(spirv::FunctionControl::Pure,
1441 DISPATCH(spirv::FunctionControl::Const,
1455 &signatureConverter))) {
1472 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1476 rewriter.
create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1480 rewriter.
eraseBlock(&newModuleOp.getBodyRegion().back());
1481 rewriter.
eraseOp(spvModuleOp);
1490 class VectorShufflePattern
1495 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1498 auto components = adaptor.getComponents();
1499 auto vector1 = adaptor.getVector1();
1500 auto vector2 = adaptor.getVector2();
1501 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1502 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1503 if (vector1Size == vector2Size) {
1505 op, vector1, vector2,
1506 LLVM::convertArrayToIndices<int32_t>(components));
1510 auto dstType = typeConverter.convertType(op.getType());
1513 auto scalarType = cast<VectorType>(dstType).getElementType();
1514 auto componentsArray = components.getValue();
1517 Value targetOp = rewriter.
create<LLVM::UndefOp>(loc, dstType);
1518 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1519 if (!isa<IntegerAttr>(componentsArray[i]))
1520 return op.
emitError(
"unable to support non-constant component");
1522 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1527 Value baseVector = vector1;
1528 if (indexVal >= vector1Size) {
1529 offsetVal = vector1Size;
1530 baseVector = vector2;
1533 Value dstIndex = rewriter.
create<LLVM::ConstantOp>(
1539 auto extractOp = rewriter.
create<LLVM::ExtractElementOp>(
1540 loc, scalarType, baseVector, index);
1541 targetOp = rewriter.
create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1542 extractOp, dstIndex);
1555 spirv::ClientAPI clientAPI) {
1572 spirv::ClientAPI clientAPI) {
1575 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1576 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1577 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1578 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1579 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1580 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1581 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1582 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1583 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1584 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1585 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1586 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1587 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1590 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1591 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1592 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1593 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1594 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1595 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1596 NotPattern<spirv::NotOp>,
1599 BitcastConversionPattern,
1600 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1601 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1602 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1603 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1604 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1605 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1606 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1609 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1610 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1611 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1612 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1613 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1614 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1615 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1616 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1617 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1618 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1619 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1620 LLVM::FCmpPredicate::uge>,
1621 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1622 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1623 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1624 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1625 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1626 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1627 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1628 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1629 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1630 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1631 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1634 ConstantScalarAndVectorPattern,
1637 BranchConversionPattern, BranchConditionalConversionPattern,
1638 FunctionCallPattern, LoopPattern, SelectionPattern,
1639 ErasePattern<spirv::MergeOp>,
1642 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1645 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1646 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1647 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1648 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1649 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1650 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1651 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1652 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1653 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1654 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1655 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1656 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1657 InverseSqrtPattern, TanPattern, TanhPattern,
1660 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1661 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1662 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1663 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1664 NotPattern<spirv::LogicalNotOp>,
1667 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1668 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1671 CompositeExtractPattern, CompositeInsertPattern,
1672 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1673 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1674 VectorShufflePattern,
1677 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1678 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1679 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1682 ReturnPattern, ReturnValuePattern>(patterns.
getContext(), typeConverter);
1684 patterns.
add<GlobalVariablePattern>(clientAPI, patterns.
getContext(),
1690 patterns.
add<FuncConversionPattern>(patterns.
getContext(), typeConverter);
1695 patterns.
add<ModuleConversionPattern>(patterns.
getContext(), typeConverter);
1706 auto spvModules = module.getOps<spirv::ModuleOp>();
1707 for (
auto spvModule : spvModules) {
1708 spvModule.walk([&](spirv::GlobalVariableOp op) {
1709 IntegerAttr descriptorSet =
1714 if (descriptorSet && binding) {
1717 auto moduleAndName =
1718 spvModule.getName().has_value()
1719 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
1720 : op.getSymName().str();
1722 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
1723 std::to_string(descriptorSet.getInt()),
1724 std::to_string(binding.getInt()));
1730 op.
emitError(
"unable to replace all symbol uses for ") << name;
static MLIRContext * getContext(OpFoldResult val)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass)
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
constexpr unsigned defaultAddressSpace
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
#define STORAGE_SPACE_MAP(storage, space)
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
static Type convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
#define CLIENT_MAP(client, storage)
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.