23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/FormatVariadic.h"
26#define DEBUG_TYPE "spirv-to-llvm-pattern"
38 if (
auto vecType = dyn_cast<VectorType>(type))
39 return vecType.getElementType().isSignedInteger();
47 if (
auto vecType = dyn_cast<VectorType>(type))
48 return vecType.getElementType().isUnsignedInteger();
55 if (
auto intType = dyn_cast<IntegerType>(type))
56 return intType.getWidth();
57 if (
auto vecType = dyn_cast<VectorType>(type))
58 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59 return intType.getWidth();
66 "bitwidth is not supported for this type");
69 auto vecType = dyn_cast<VectorType>(type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
78 if (
auto vecTy = dyn_cast<VectorType>(type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(type).getWidth();
85 if (
auto vecType = dyn_cast<VectorType>(type)) {
86 auto integerType = cast<IntegerType>(vecType.getElementType());
89 auto integerType = cast<IntegerType>(type);
96 if (isa<VectorType>(srcType)) {
97 return LLVM::ConstantOp::create(
98 rewriter, loc, dstType,
102 return LLVM::ConstantOp::create(rewriter, loc, dstType,
109 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
110 auto floatType = cast<FloatType>(vecType.getElementType());
111 return LLVM::ConstantOp::create(
112 rewriter, loc, dstType,
116 auto floatType = cast<FloatType>(srcType);
117 return LLVM::ConstantOp::create(rewriter, loc, dstType,
130 auto srcType = value.
getType();
136 if (valueBitWidth < targetBitWidth)
137 return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
142 if (valueBitWidth > targetBitWidth)
143 return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
150 ConversionPatternRewriter &rewriter) {
151 auto vectorType = VectorType::get(numElements, toBroadcast.
getType());
152 auto llvmVectorType = typeConverter.convertType(vectorType);
153 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
154 Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
155 for (
unsigned i = 0; i < numElements; ++i) {
156 auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
157 rewriter.getI32IntegerAttr(i));
158 broadcasted = LLVM::InsertElementOp::create(
159 rewriter, loc, llvmVectorType, broadcasted, toBroadcast,
index);
167 ConversionPatternRewriter &rewriter) {
168 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, value, numElements, typeConverter, rewriter);
187 ConversionPatternRewriter &rewriter) {
201 if (failed(converter.convertTypes(type.
getElementTypes(), elementsVector)))
203 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
211 if (failed(converter.convertTypes(type.
getElementTypes(), elementsVector)))
213 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
220 return LLVM::ConstantOp::create(
221 rewriter, loc, IntegerType::get(rewriter.
getContext(), 32),
227 ConversionPatternRewriter &rewriter,
229 unsigned alignment,
bool isVolatile,
230 bool isNonTemporal) {
231 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232 auto dstType = typeConverter.convertType(loadOp.getType());
234 return rewriter.notifyMatchFailure(op,
"type conversion failed");
235 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237 isVolatile, isNonTemporal);
240 auto storeOp = cast<spirv::StoreOp>(op);
241 spirv::StoreOpAdaptor adaptor(operands);
242 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
243 adaptor.getPtr(), alignment,
244 isVolatile, isNonTemporal);
259 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
263 auto llvmElementType = converter.convertType(elementType);
265 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
275 return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
286 return LLVM::LLVMArrayType::get(elementType, 0);
295 if (!memberDecorations.empty())
310 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter)
const override {
316 getTypeConverter()->convertType(op.getComponentPtr().getType());
318 return rewriter.notifyMatchFailure(op,
"type conversion failed");
320 auto indices = llvm::to_vector<4>(adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(indexType);
324 return rewriter.notifyMatchFailure(op,
"type conversion failed");
326 LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
327 rewriter.getIntegerAttr(indexType, 0));
330 auto elementType = getTypeConverter()->convertType(
331 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
333 return rewriter.notifyMatchFailure(op,
"type conversion failed");
334 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
335 adaptor.getBasePtr(),
indices);
342 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
345 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter)
const override {
347 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
349 return rewriter.notifyMatchFailure(op,
"type conversion failed");
350 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
356class BitFieldInsertPattern
359 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
362 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 auto srcType = op.getType();
365 auto dstType = getTypeConverter()->convertType(srcType);
367 return rewriter.notifyMatchFailure(op,
"type conversion failed");
368 Location loc = op.getLoc();
372 *getTypeConverter(), rewriter);
374 *getTypeConverter(), rewriter);
378 Value maskShiftedByCount =
379 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
380 Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
381 maskShiftedByCount, minusOne);
382 Value maskShiftedByCountAndOffset =
383 LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
384 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
385 maskShiftedByCountAndOffset, minusOne);
390 LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
391 Value insertShiftedByOffset =
392 LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
393 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
394 insertShiftedByOffset);
400class ConstantScalarAndVectorPattern
403 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
406 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter)
const override {
408 auto srcType = constOp.getType();
409 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
412 auto dstType = getTypeConverter()->convertType(srcType);
414 return rewriter.notifyMatchFailure(constOp,
"type conversion failed");
423 auto signlessType = rewriter.getIntegerType(
getBitWidth(srcType));
425 if (isa<VectorType>(srcType)) {
426 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
427 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
429 dstElementsAttr.mapValues(
430 signlessType, [&](
const APInt &value) {
return value; }));
433 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
434 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
435 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
438 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
439 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
444class BitFieldSExtractPattern
447 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
450 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter)
const override {
452 auto srcType = op.getType();
453 auto dstType = getTypeConverter()->convertType(srcType);
455 return rewriter.notifyMatchFailure(op,
"type conversion failed");
456 Location loc = op.getLoc();
460 *getTypeConverter(), rewriter);
462 *getTypeConverter(), rewriter);
465 IntegerType integerType;
466 if (
auto vecType = dyn_cast<VectorType>(srcType))
467 integerType = cast<IntegerType>(vecType.getElementType());
469 integerType = cast<IntegerType>(srcType);
471 auto baseSize = rewriter.getIntegerAttr(integerType,
getBitWidth(srcType));
473 isa<VectorType>(srcType)
474 ? LLVM::ConstantOp::create(
475 rewriter, loc, dstType,
476 SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
477 : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
481 Value countPlusOffset =
482 LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
483 Value amountToShiftLeft =
484 LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
485 Value baseShiftedLeft = LLVM::ShlOp::create(
486 rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
489 Value amountToShiftRight =
490 LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
491 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
497class BitFieldUExtractPattern
500 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
503 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter)
const override {
505 auto srcType = op.getType();
506 auto dstType = getTypeConverter()->convertType(srcType);
508 return rewriter.notifyMatchFailure(op,
"type conversion failed");
509 Location loc = op.getLoc();
513 *getTypeConverter(), rewriter);
515 *getTypeConverter(), rewriter);
519 Value maskShiftedByCount =
520 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
521 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
526 LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
527 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
534 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
537 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
538 ConversionPatternRewriter &rewriter)
const override {
539 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
540 branchOp.getTarget());
545class BranchConditionalConversionPattern
548 using SPIRVToLLVMConversion<
549 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
552 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter)
const override {
556 if (
auto weights = op.getBranchWeights()) {
557 SmallVector<int32_t> weightValues;
558 for (
auto weight : weights->getAsRange<IntegerAttr>())
559 weightValues.push_back(weight.getInt());
563 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
564 op, op.getCondition(), op.getTrueBlockArguments(),
565 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
574class CompositeExtractPattern
577 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
580 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
581 ConversionPatternRewriter &rewriter)
const override {
582 auto dstType = this->getTypeConverter()->convertType(op.getType());
584 return rewriter.notifyMatchFailure(op,
"type conversion failed");
586 Type containerType = op.getComposite().getType();
587 if (isa<VectorType>(containerType)) {
588 Location loc = op.getLoc();
589 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
591 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
592 op, dstType, adaptor.getComposite(), index);
596 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
597 op, adaptor.getComposite(),
606class CompositeInsertPattern
609 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
612 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
613 ConversionPatternRewriter &rewriter)
const override {
614 auto dstType = this->getTypeConverter()->convertType(op.getType());
616 return rewriter.notifyMatchFailure(op,
"type conversion failed");
618 Type containerType = op.getComposite().getType();
619 if (isa<VectorType>(containerType)) {
620 Location loc = op.getLoc();
621 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
623 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
624 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
629 op, adaptor.getComposite(), adaptor.getObject(),
637template <
typename SPIRVOp,
typename LLVMOp>
640 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
643 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
644 ConversionPatternRewriter &rewriter)
const override {
645 auto dstType = this->getTypeConverter()->convertType(op.getType());
647 return rewriter.notifyMatchFailure(op,
"type conversion failed");
648 rewriter.template replaceOpWithNewOp<LLVMOp>(
649 op, dstType, adaptor.getOperands(), op->getAttrs());
656class ExecutionModePattern
659 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
662 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter)
const override {
667 ModuleOp module = op->getParentOfType<ModuleOp>();
668 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
669 std::string moduleName;
670 if (module.getName().has_value())
671 moduleName =
"_" +
module.getName()->str();
674 std::string executionModeInfoName = llvm::formatv(
675 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
676 static_cast<uint32_t
>(executionModeAttr.getValue()));
678 MLIRContext *context = rewriter.getContext();
679 OpBuilder::InsertionGuard guard(rewriter);
680 rewriter.setInsertionPointToStart(module.getBody());
687 auto llvmI32Type = IntegerType::get(context, 32);
688 SmallVector<Type, 2> fields;
689 fields.push_back(llvmI32Type);
691 if (!values.empty()) {
692 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
693 fields.push_back(arrayType);
695 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698 auto global = LLVM::GlobalOp::create(
699 rewriter, UnknownLoc::get(context), structType,
true,
700 LLVM::Linkage::External, executionModeInfoName, Attribute(),
702 Location loc = global.getLoc();
703 Region ®ion = global.getInitializerRegion();
704 Block *block = rewriter.createBlock(®ion);
707 rewriter.setInsertionPointToStart(block);
708 Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
709 Value executionMode = LLVM::ConstantOp::create(
710 rewriter, loc, llvmI32Type,
711 rewriter.getI32IntegerAttr(
712 static_cast<uint32_t
>(executionModeAttr.getValue())));
713 SmallVector<int64_t> position{0};
714 structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
715 executionMode, position);
718 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
719 auto attr = values.getValue()[i];
720 Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
721 structValue = LLVM::InsertValueOp::create(
722 rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i}));
724 LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue}));
725 rewriter.eraseOp(op);
734class GlobalVariablePattern
737 template <
typename... Args>
738 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
739 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
740 std::forward<Args>(args)...),
741 clientAPI(clientAPI) {}
744 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
745 ConversionPatternRewriter &rewriter)
const override {
748 if (op.getInitializer())
751 auto srcType = cast<spirv::PointerType>(op.getType());
752 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
754 return rewriter.notifyMatchFailure(op,
"type conversion failed");
759 auto storageClass = srcType.getStorageClass();
760 switch (storageClass) {
761 case spirv::StorageClass::Input:
762 case spirv::StorageClass::Private:
763 case spirv::StorageClass::Output:
764 case spirv::StorageClass::StorageBuffer:
765 case spirv::StorageClass::UniformConstant:
774 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
775 (storageClass == spirv::StorageClass::UniformConstant);
781 auto linkage = storageClass == spirv::StorageClass::Private
782 ? LLVM::Linkage::Private
783 : LLVM::Linkage::External;
784 StringAttr locationAttrName = op.getLocationAttrName();
785 IntegerAttr locationAttr = op.getLocationAttr();
786 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
792 newGlobalOp->setAttr(locationAttrName, locationAttr);
798 spirv::ClientAPI clientAPI;
803template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
806 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
809 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
810 ConversionPatternRewriter &rewriter)
const override {
812 Type fromType = op.getOperand().getType();
813 Type toType = op.getType();
815 auto dstType = this->getTypeConverter()->convertType(toType);
817 return rewriter.notifyMatchFailure(op,
"type conversion failed");
820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821 adaptor.getOperands());
825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826 adaptor.getOperands());
833class FunctionCallPattern
836 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840 ConversionPatternRewriter &rewriter)
const override {
841 if (callOp.getNumResults() == 0) {
842 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843 callOp,
TypeRange(), adaptor.getOperands(), callOp->getAttrs());
844 newOp.getProperties().operandSegmentSizes = {
845 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
846 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
851 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
853 return rewriter.notifyMatchFailure(callOp,
"type conversion failed");
854 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856 newOp.getProperties().operandSegmentSizes = {
857 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
858 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
864template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
867 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
870 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
871 ConversionPatternRewriter &rewriter)
const override {
873 auto dstType = this->getTypeConverter()->convertType(op.getType());
875 return rewriter.notifyMatchFailure(op,
"type conversion failed");
877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878 op, dstType, predicate, op.getOperand1(), op.getOperand2());
884template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
887 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
890 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
891 ConversionPatternRewriter &rewriter)
const override {
893 auto dstType = this->getTypeConverter()->convertType(op.getType());
895 return rewriter.notifyMatchFailure(op,
"type conversion failed");
897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898 op, dstType, predicate, op.getOperand1(), op.getOperand2());
903class InverseSqrtPattern
906 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910 ConversionPatternRewriter &rewriter)
const override {
911 auto srcType = op.getType();
912 auto dstType = getTypeConverter()->convertType(srcType);
914 return rewriter.notifyMatchFailure(op,
"type conversion failed");
916 Location loc = op.getLoc();
918 Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
919 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
925template <
typename SPIRVOp>
928 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
931 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
932 ConversionPatternRewriter &rewriter)
const override {
933 if (!op.getMemoryAccess()) {
935 *this->getTypeConverter(), 0,
939 auto memoryAccess = *op.getMemoryAccess();
940 switch (memoryAccess) {
941 case spirv::MemoryAccess::Aligned:
942 case spirv::MemoryAccess::None:
943 case spirv::MemoryAccess::Nontemporal:
944 case spirv::MemoryAccess::Volatile: {
946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
950 *this->getTypeConverter(), alignment,
951 isVolatile, isNonTemporal);
961template <
typename SPIRVOp>
964 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
967 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
968 ConversionPatternRewriter &rewriter)
const override {
969 auto srcType = notOp.getType();
970 auto dstType = this->getTypeConverter()->convertType(srcType);
972 return rewriter.notifyMatchFailure(notOp,
"type conversion failed");
974 Location loc = notOp.getLoc();
977 isa<VectorType>(srcType)
978 ? LLVM::ConstantOp::create(
979 rewriter, loc, dstType,
980 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981 : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983 notOp.getOperand(), mask);
989template <
typename SPIRVOp>
992 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
995 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
996 ConversionPatternRewriter &rewriter)
const override {
997 rewriter.eraseOp(op);
1004 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008 ConversionPatternRewriter &rewriter)
const override {
1009 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1017 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter)
const override {
1022 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023 adaptor.getOperands());
1032 bool convergent =
true) {
1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1039 func = LLVM::LLVMFuncOp::create(
1040 b, symbolTable->
getLoc(), name,
1041 LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043 func.setConvergent(convergent);
1044 func.setNoUnwind(
true);
1045 func.setWillReturn(
true);
1050 LLVM::LLVMFuncOp
func,
1052 auto call = LLVM::CallOp::create(builder, loc,
func, args);
1053 call.setCConv(
func.getCConv());
1054 call.setConvergentAttr(
func.getConvergentAttr());
1055 call.setNoUnwindAttr(
func.getNoUnwindAttr());
1056 call.setWillReturnAttr(
func.getWillReturnAttr());
1060template <
typename BarrierOpTy>
1063 using OpAdaptor =
typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1065 using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1067 static constexpr StringRef getFuncName();
1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071 ConversionPatternRewriter &rewriter)
const override {
1072 constexpr StringRef funcName = getFuncName();
1073 Operation *symbolTable =
1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1076 Type i32 = rewriter.getI32Type();
1078 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079 LLVM::LLVMFuncOp func =
1082 Location loc = controlBarrierOp->getLoc();
1083 Value execution = LLVM::ConstantOp::create(
1084 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1085 Value memory = LLVM::ConstantOp::create(
1086 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1087 Value semantics = LLVM::ConstantOp::create(
1088 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1091 {execution, memory, semantics});
1093 rewriter.replaceOp(controlBarrierOp, call);
1100StringRef getTypeMangling(
Type type,
bool isSigned) {
1102 .Case<Float16Type>([](
auto) {
return "Dh"; })
1103 .Case<Float32Type>([](
auto) {
return "f"; })
1104 .Case<Float64Type>([](
auto) {
return "d"; })
1105 .Case<IntegerType>([isSigned](IntegerType intTy) {
1106 switch (intTy.getWidth()) {
1110 return (isSigned) ?
"a" :
"c";
1112 return (isSigned) ?
"s" :
"t";
1114 return (isSigned) ?
"i" :
"j";
1116 return (isSigned) ?
"l" :
"m";
1118 llvm_unreachable(
"Unsupported integer width");
1121 .DefaultUnreachable(
"No mangling defined");
1124template <
typename ReduceOp>
1125constexpr StringLiteral getGroupFuncName();
1128constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1129 return "_Z17__spirv_GroupIAddii";
1132constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1133 return "_Z17__spirv_GroupFAddii";
1136constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1137 return "_Z17__spirv_GroupSMinii";
1140constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1141 return "_Z17__spirv_GroupUMinii";
1144constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1145 return "_Z17__spirv_GroupFMinii";
1148constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1149 return "_Z17__spirv_GroupSMaxii";
1152constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1153 return "_Z17__spirv_GroupUMaxii";
1156constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1157 return "_Z17__spirv_GroupFMaxii";
1160constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1161 return "_Z27__spirv_GroupNonUniformIAddii";
1164constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1165 return "_Z27__spirv_GroupNonUniformFAddii";
1168constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1169 return "_Z27__spirv_GroupNonUniformIMulii";
1172constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1173 return "_Z27__spirv_GroupNonUniformFMulii";
1176constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1177 return "_Z27__spirv_GroupNonUniformSMinii";
1180constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1181 return "_Z27__spirv_GroupNonUniformUMinii";
1184constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1185 return "_Z27__spirv_GroupNonUniformFMinii";
1188constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1189 return "_Z27__spirv_GroupNonUniformSMaxii";
1192constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1193 return "_Z27__spirv_GroupNonUniformUMaxii";
1196constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1197 return "_Z27__spirv_GroupNonUniformFMaxii";
1200constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1201 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1204constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1205 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1208constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1209 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1212constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1213 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1216constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1217 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1220constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1221 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1225template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1228 using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1231 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1232 ConversionPatternRewriter &rewriter)
const override {
1234 Type retTy = op.getResult().getType();
1238 SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1239 funcName += getTypeMangling(retTy,
false);
1241 Type i32Ty = rewriter.getI32Type();
1242 SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1243 if constexpr (NonUniform) {
1244 if (adaptor.getClusterSize()) {
1246 paramTypes.push_back(i32Ty);
1250 Operation *symbolTable =
1251 op->template getParentWithTrait<OpTrait::SymbolTable>();
1253 LLVM::LLVMFuncOp func =
1256 Location loc = op.getLoc();
1257 Value scope = LLVM::ConstantOp::create(
1258 rewriter, loc, i32Ty,
1259 static_cast<int32_t
>(adaptor.getExecutionScope()));
1260 Value groupOp = LLVM::ConstantOp::create(
1261 rewriter, loc, i32Ty,
1262 static_cast<int32_t
>(adaptor.getGroupOperation()));
1263 SmallVector<Value> operands{scope, groupOp};
1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1267 rewriter.replaceOp(op, call);
1274ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1275 return "_Z22__spirv_ControlBarrieriii";
1280ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1281 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1286ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1287 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1340 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1344 ConversionPatternRewriter &rewriter)
const override {
1346 if (loopOp.getLoopControl() != spirv::LoopControl::None)
1350 if (loopOp.getBody().empty()) {
1351 rewriter.eraseOp(loopOp);
1355 Location loc = loopOp.getLoc();
1359 Block *currentBlock = rewriter.getBlock();
1361 Block *endBlock = rewriter.splitBlock(currentBlock, position);
1365 Block *entryBlock = loopOp.getEntryBlock();
1367 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1370 Block *headerBlock = loopOp.getHeaderBlock();
1371 rewriter.setInsertionPointToEnd(currentBlock);
1372 LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
1373 rewriter.eraseBlock(entryBlock);
1376 Block *mergeBlock = loopOp.getMergeBlock();
1379 rewriter.setInsertionPointToEnd(mergeBlock);
1380 LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
1382 rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1393 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1397 ConversionPatternRewriter &rewriter)
const override {
1401 if (op.getSelectionControl() != spirv::SelectionControl::None)
1408 if (op.getBody().getBlocks().size() <= 2) {
1409 rewriter.eraseOp(op);
1413 Location loc = op.getLoc();
1417 auto *currentBlock = rewriter.getInsertionBlock();
1418 rewriter.setInsertionPointAfter(op);
1419 auto position = rewriter.getInsertionPoint();
1420 auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1426 auto *headerBlock = op.getHeaderBlock();
1428 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1434 auto *mergeBlock = op.getMergeBlock();
1437 rewriter.setInsertionPointToEnd(mergeBlock);
1438 LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
1441 Block *trueBlock = condBrOp.getTrueBlock();
1442 Block *falseBlock = condBrOp.getFalseBlock();
1443 rewriter.setInsertionPointToEnd(currentBlock);
1444 LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
1445 condBrOp.getTrueTargetOperands(), falseBlock,
1446 condBrOp.getFalseTargetOperands());
1448 rewriter.eraseBlock(headerBlock);
1449 rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1450 rewriter.replaceOp(op, continueBlock->getArguments());
1459template <
typename SPIRVOp,
typename LLVMOp>
1462 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1465 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1466 ConversionPatternRewriter &rewriter)
const override {
1468 auto dstType = this->getTypeConverter()->convertType(op.getType());
1470 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1472 Type op1Type = op.getOperand1().getType();
1473 Type op2Type = op.getOperand2().getType();
1475 if (op1Type == op2Type) {
1476 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477 adaptor.getOperands());
1481 std::optional<uint64_t> dstTypeWidth =
1483 std::optional<uint64_t> op2TypeWidth =
1486 if (!dstTypeWidth || !op2TypeWidth)
1489 Location loc = op.getLoc();
1491 if (op2TypeWidth < dstTypeWidth) {
1494 LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1497 LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1499 }
else if (op2TypeWidth == dstTypeWidth) {
1500 extended = adaptor.getOperand2();
1506 LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
1507 rewriter.replaceOp(op,
result);
1514 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1517 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1518 ConversionPatternRewriter &rewriter)
const override {
1519 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1521 return rewriter.notifyMatchFailure(tanOp,
"type conversion failed");
1523 rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
1524 adaptor.getOperands());
1531 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1534 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1535 ConversionPatternRewriter &rewriter)
const override {
1536 auto srcType = tanhOp.getType();
1537 auto dstType = getTypeConverter()->convertType(srcType);
1539 return rewriter.notifyMatchFailure(tanhOp,
"type conversion failed");
1541 rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
1542 adaptor.getOperands());
1549 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1552 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1553 ConversionPatternRewriter &rewriter)
const override {
1554 auto srcType = varOp.getType();
1556 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1557 auto init = varOp.getInitializer();
1558 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1561 auto dstType = getTypeConverter()->convertType(srcType);
1563 return rewriter.notifyMatchFailure(varOp,
"type conversion failed");
1565 Location loc = varOp.getLoc();
1568 auto elementType = getTypeConverter()->convertType(pointerTo);
1570 return rewriter.notifyMatchFailure(varOp,
"type conversion failed");
1571 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1575 auto elementType = getTypeConverter()->convertType(pointerTo);
1577 return rewriter.notifyMatchFailure(varOp,
"type conversion failed");
1579 LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
1580 LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
1581 rewriter.replaceOp(varOp, allocated);
1590class BitcastConversionPattern
1593 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1596 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1597 ConversionPatternRewriter &rewriter)
const override {
1598 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1600 return rewriter.notifyMatchFailure(bitcastOp,
"type conversion failed");
1603 if (isa<LLVM::LLVMPointerType>(dstType)) {
1604 rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1608 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1609 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1620 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1623 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1624 ConversionPatternRewriter &rewriter)
const override {
1628 auto funcType = funcOp.getFunctionType();
1629 TypeConverter::SignatureConversion signatureConverter(
1630 funcType.getNumInputs());
1631 auto llvmType =
static_cast<const LLVMTypeConverter *
>(getTypeConverter())
1632 ->convertFunctionSignature(
1634 false, signatureConverter);
1639 Location loc = funcOp.getLoc();
1640 StringRef name = funcOp.getName();
1641 auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
1644 MLIRContext *context = funcOp.getContext();
1645 switch (funcOp.getFunctionControl()) {
1646 case spirv::FunctionControl::Inline:
1647 newFuncOp.setAlwaysInline(
true);
1649 case spirv::FunctionControl::DontInline:
1650 newFuncOp.setNoInline(
true);
1653#define DISPATCH(functionControl, llvmAttr) \
1654 case functionControl: \
1655 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1658 DISPATCH(spirv::FunctionControl::Pure,
1659 StringAttr::get(context,
"readonly"));
1660 DISPATCH(spirv::FunctionControl::Const,
1661 StringAttr::get(context,
"readnone"));
1671 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1673 if (
failed(rewriter.convertRegionTypes(
1674 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1677 rewriter.eraseOp(funcOp);
1688 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1691 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1692 ConversionPatternRewriter &rewriter)
const override {
1695 ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
1696 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1699 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1700 rewriter.eraseOp(spvModuleOp);
1709class VectorShufflePattern
1712 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1714 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1715 ConversionPatternRewriter &rewriter)
const override {
1716 Location loc = op.getLoc();
1717 auto components = adaptor.getComponents();
1718 auto vector1 = adaptor.getVector1();
1719 auto vector2 = adaptor.getVector2();
1720 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1721 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1722 if (vector1Size == vector2Size) {
1723 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1724 op, vector1, vector2,
1729 auto dstType = getTypeConverter()->convertType(op.getType());
1731 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1732 auto scalarType = cast<VectorType>(dstType).getElementType();
1733 auto componentsArray = components.getValue();
1734 auto *context = rewriter.getContext();
1735 auto llvmI32Type = IntegerType::get(context, 32);
1736 Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
1737 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1738 if (!isa<IntegerAttr>(componentsArray[i]))
1739 return op.emitError(
"unable to support non-constant component");
1741 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1746 Value baseVector = vector1;
1747 if (indexVal >= vector1Size) {
1748 offsetVal = vector1Size;
1749 baseVector = vector2;
1752 Value dstIndex = LLVM::ConstantOp::create(
1753 rewriter, loc, llvmI32Type,
1754 rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1755 Value index = LLVM::ConstantOp::create(
1756 rewriter, loc, llvmI32Type,
1757 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1759 auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
1761 targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
1762 extractOp, dstIndex);
1764 rewriter.replaceOp(op, targetOp);
1775 spirv::ClientAPI clientAPI) {
1792 spirv::ClientAPI clientAPI) {
1795 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1796 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1797 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1798 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1799 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1800 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1801 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1802 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1803 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1804 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1805 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1806 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1807 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1810 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1811 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1812 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1813 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1814 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1815 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1816 NotPattern<spirv::NotOp>,
1819 BitcastConversionPattern,
1820 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1821 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1822 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1823 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1824 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1825 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1826 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1829 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1830 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1831 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1832 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1833 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1834 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1835 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1836 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1837 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1838 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1839 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1840 LLVM::FCmpPredicate::uge>,
1841 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1842 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1843 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1844 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1845 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1846 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1847 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1848 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1849 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1850 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1851 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1854 ConstantScalarAndVectorPattern,
1857 BranchConversionPattern, BranchConditionalConversionPattern,
1858 FunctionCallPattern, LoopPattern, SelectionPattern,
1859 ErasePattern<spirv::MergeOp>,
1862 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1865 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1866 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1867 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1868 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1869 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1870 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1871 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1872 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1873 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1874 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1875 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1876 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1877 InverseSqrtPattern, TanPattern, TanhPattern,
1880 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1881 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1882 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1883 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1884 NotPattern<spirv::LogicalNotOp>,
1887 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1888 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1891 CompositeExtractPattern, CompositeInsertPattern,
1892 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1893 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1894 VectorShufflePattern,
1897 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1898 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1899 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1902 ReturnPattern, ReturnValuePattern,
1905 ControlBarrierPattern<spirv::ControlBarrierOp>,
1906 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1907 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1910 GroupReducePattern<spirv::GroupIAddOp>,
1911 GroupReducePattern<spirv::GroupFAddOp>,
1912 GroupReducePattern<spirv::GroupFMinOp>,
1913 GroupReducePattern<spirv::GroupUMinOp>,
1914 GroupReducePattern<spirv::GroupSMinOp,
true>,
1915 GroupReducePattern<spirv::GroupFMaxOp>,
1916 GroupReducePattern<spirv::GroupUMaxOp>,
1917 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1918 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1920 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1922 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1924 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1926 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1928 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1930 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1932 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1934 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1936 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1938 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1940 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1942 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1944 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1946 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1948 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1963 patterns.add<ModuleConversionPattern>(
patterns.getContext(), typeConverter);
1974 auto spvModules =
module.getOps<spirv::ModuleOp>();
1975 for (
auto spvModule : spvModules) {
1976 spvModule.walk([&](spirv::GlobalVariableOp op) {
1977 IntegerAttr descriptorSet =
1979 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
1982 if (descriptorSet && binding) {
1985 auto moduleAndName =
1986 spvModule.getName().has_value()
1987 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
1988 : op.getSymName().str();
1990 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
1991 std::to_string(descriptorSet.getInt()),
1992 std::to_string(binding.getInt()));
1993 auto nameAttr = StringAttr::get(op->getContext(), name);
1998 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
#define DISPATCH(functionControl, llvmAttr)
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
OpListType::iterator iterator
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.