23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/FormatVariadic.h"
26#define DEBUG_TYPE "spirv-to-llvm-pattern"
38 if (
auto vecType = dyn_cast<VectorType>(type))
39 return vecType.getElementType().isSignedInteger();
47 if (
auto vecType = dyn_cast<VectorType>(type))
48 return vecType.getElementType().isUnsignedInteger();
55 if (
auto intType = dyn_cast<IntegerType>(type))
56 return intType.getWidth();
57 if (
auto vecType = dyn_cast<VectorType>(type))
58 if (
auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59 return intType.getWidth();
66 "bitwidth is not supported for this type");
69 auto vecType = dyn_cast<VectorType>(type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
78 if (
auto vecTy = dyn_cast<VectorType>(type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(type).getWidth();
85 if (
auto vecType = dyn_cast<VectorType>(type)) {
86 auto integerType = cast<IntegerType>(vecType.getElementType());
89 auto integerType = cast<IntegerType>(type);
96 if (isa<VectorType>(srcType)) {
97 return LLVM::ConstantOp::create(
98 rewriter, loc, dstType,
102 return LLVM::ConstantOp::create(rewriter, loc, dstType,
109 if (
auto vecType = dyn_cast<VectorType>(srcType)) {
110 auto floatType = cast<FloatType>(vecType.getElementType());
111 return LLVM::ConstantOp::create(
112 rewriter, loc, dstType,
116 auto floatType = cast<FloatType>(srcType);
117 return LLVM::ConstantOp::create(rewriter, loc, dstType,
130 auto srcType = value.
getType();
136 if (valueBitWidth < targetBitWidth)
137 return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
142 if (valueBitWidth > targetBitWidth)
143 return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
150 ConversionPatternRewriter &rewriter) {
151 auto vectorType = VectorType::get(numElements, toBroadcast.
getType());
152 auto llvmVectorType = typeConverter.convertType(vectorType);
153 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
154 Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
155 for (
unsigned i = 0; i < numElements; ++i) {
156 auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
157 rewriter.getI32IntegerAttr(i));
158 broadcasted = LLVM::InsertElementOp::create(
159 rewriter, loc, llvmVectorType, broadcasted, toBroadcast,
index);
167 ConversionPatternRewriter &rewriter) {
168 if (
auto vectorType = dyn_cast<VectorType>(srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, value, numElements, typeConverter, rewriter);
187 ConversionPatternRewriter &rewriter) {
201 if (failed(converter.convertTypes(type.
getElementTypes(), elementsVector)))
203 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
211 if (failed(converter.convertTypes(type.
getElementTypes(), elementsVector)))
213 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
220 return LLVM::ConstantOp::create(
221 rewriter, loc, IntegerType::get(rewriter.
getContext(), 32),
227 ConversionPatternRewriter &rewriter,
229 unsigned alignment,
bool isVolatile,
230 bool isNonTemporal) {
231 if (
auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232 auto dstType = typeConverter.convertType(loadOp.getType());
234 return rewriter.notifyMatchFailure(op,
"type conversion failed");
235 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237 isVolatile, isNonTemporal);
240 auto storeOp = cast<spirv::StoreOp>(op);
241 spirv::StoreOpAdaptor adaptor(operands);
242 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
243 adaptor.getPtr(), alignment,
244 isVolatile, isNonTemporal);
259 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
263 auto llvmElementType = converter.convertType(elementType);
265 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
275 return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
286 return LLVM::LLVMArrayType::get(elementType, 0);
295 if (!memberDecorations.empty())
310 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter)
const override {
316 getTypeConverter()->convertType(op.getComponentPtr().getType());
318 return rewriter.notifyMatchFailure(op,
"type conversion failed");
320 auto indices = llvm::to_vector<4>(adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(indexType);
324 return rewriter.notifyMatchFailure(op,
"type conversion failed");
326 LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
327 rewriter.getIntegerAttr(indexType, 0));
330 auto elementType = getTypeConverter()->convertType(
331 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
333 return rewriter.notifyMatchFailure(op,
"type conversion failed");
334 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
335 adaptor.getBasePtr(),
indices);
342 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
345 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter)
const override {
347 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
349 return rewriter.notifyMatchFailure(op,
"type conversion failed");
350 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
356class BitFieldInsertPattern
359 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
362 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter)
const override {
364 auto srcType = op.getType();
365 auto dstType = getTypeConverter()->convertType(srcType);
367 return rewriter.notifyMatchFailure(op,
"type conversion failed");
368 Location loc = op.getLoc();
372 *getTypeConverter(), rewriter);
374 *getTypeConverter(), rewriter);
378 Value maskShiftedByCount =
379 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
380 Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
381 maskShiftedByCount, minusOne);
382 Value maskShiftedByCountAndOffset =
383 LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
384 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
385 maskShiftedByCountAndOffset, minusOne);
390 LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
391 Value insertShiftedByOffset =
392 LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
393 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
394 insertShiftedByOffset);
400class ConstantScalarAndVectorPattern
403 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
406 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter)
const override {
408 auto srcType = constOp.getType();
409 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
412 auto dstType = getTypeConverter()->convertType(srcType);
414 return rewriter.notifyMatchFailure(constOp,
"type conversion failed");
423 auto signlessType = rewriter.getIntegerType(
getBitWidth(srcType));
425 if (isa<VectorType>(srcType)) {
426 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
427 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
429 dstElementsAttr.mapValues(
430 signlessType, [&](
const APInt &value) {
return value; }));
433 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
434 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
435 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
438 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
439 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
444class BitFieldSExtractPattern
447 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
450 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter)
const override {
452 auto srcType = op.getType();
453 auto dstType = getTypeConverter()->convertType(srcType);
455 return rewriter.notifyMatchFailure(op,
"type conversion failed");
456 Location loc = op.getLoc();
460 *getTypeConverter(), rewriter);
462 *getTypeConverter(), rewriter);
465 IntegerType integerType;
466 if (
auto vecType = dyn_cast<VectorType>(srcType))
467 integerType = cast<IntegerType>(vecType.getElementType());
469 integerType = cast<IntegerType>(srcType);
471 auto baseSize = rewriter.getIntegerAttr(integerType,
getBitWidth(srcType));
473 isa<VectorType>(srcType)
474 ? LLVM::ConstantOp::create(
475 rewriter, loc, dstType,
476 SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
477 : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
481 Value countPlusOffset =
482 LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
483 Value amountToShiftLeft =
484 LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
485 Value baseShiftedLeft = LLVM::ShlOp::create(
486 rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
489 Value amountToShiftRight =
490 LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
491 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
497class BitFieldUExtractPattern
500 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
503 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter)
const override {
505 auto srcType = op.getType();
506 auto dstType = getTypeConverter()->convertType(srcType);
508 return rewriter.notifyMatchFailure(op,
"type conversion failed");
509 Location loc = op.getLoc();
513 *getTypeConverter(), rewriter);
515 *getTypeConverter(), rewriter);
519 Value maskShiftedByCount =
520 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
521 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
526 LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
527 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
534 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
537 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
538 ConversionPatternRewriter &rewriter)
const override {
539 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
540 branchOp.getTarget());
545class BranchConditionalConversionPattern
548 using SPIRVToLLVMConversion<
549 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
552 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter)
const override {
556 if (
auto weights = op.getBranchWeights()) {
557 SmallVector<int32_t> weightValues;
558 for (
auto weight : weights->getAsRange<IntegerAttr>())
559 weightValues.push_back(weight.getInt());
563 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
564 op, op.getCondition(), op.getTrueBlockArguments(),
565 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
574class CompositeExtractPattern
577 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
580 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
581 ConversionPatternRewriter &rewriter)
const override {
582 auto dstType = this->getTypeConverter()->convertType(op.getType());
584 return rewriter.notifyMatchFailure(op,
"type conversion failed");
586 Type containerType = op.getComposite().getType();
587 if (isa<VectorType>(containerType)) {
588 Location loc = op.getLoc();
589 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
591 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
592 op, dstType, adaptor.getComposite(), index);
596 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
597 op, adaptor.getComposite(),
606class CompositeInsertPattern
609 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
612 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
613 ConversionPatternRewriter &rewriter)
const override {
614 auto dstType = this->getTypeConverter()->convertType(op.getType());
616 return rewriter.notifyMatchFailure(op,
"type conversion failed");
618 Type containerType = op.getComposite().getType();
619 if (isa<VectorType>(containerType)) {
620 Location loc = op.getLoc();
621 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
623 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
624 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
628 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
629 op, adaptor.getComposite(), adaptor.getObject(),
637template <
typename SPIRVOp,
typename LLVMOp>
640 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
643 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
644 ConversionPatternRewriter &rewriter)
const override {
645 auto dstType = this->getTypeConverter()->convertType(op.getType());
647 return rewriter.notifyMatchFailure(op,
"type conversion failed");
648 rewriter.template replaceOpWithNewOp<LLVMOp>(
649 op, dstType, adaptor.getOperands(), op->getAttrs());
656class ExecutionModePattern
659 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
662 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter)
const override {
667 ModuleOp module = op->getParentOfType<ModuleOp>();
668 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
669 std::string moduleName;
670 if (module.getName().has_value())
671 moduleName =
"_" +
module.getName()->str();
674 std::string executionModeInfoName = llvm::formatv(
675 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
676 static_cast<uint32_t
>(executionModeAttr.getValue()));
678 MLIRContext *context = rewriter.getContext();
679 OpBuilder::InsertionGuard guard(rewriter);
680 rewriter.setInsertionPointToStart(module.getBody());
687 auto llvmI32Type = IntegerType::get(context, 32);
688 SmallVector<Type, 2> fields;
689 fields.push_back(llvmI32Type);
691 if (!values.empty()) {
692 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
693 fields.push_back(arrayType);
695 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698 auto global = LLVM::GlobalOp::create(
699 rewriter, UnknownLoc::get(context), structType,
true,
700 LLVM::Linkage::External, executionModeInfoName, Attribute(),
702 Location loc = global.getLoc();
703 Region ®ion = global.getInitializerRegion();
704 Block *block = rewriter.createBlock(®ion);
707 rewriter.setInsertionPointToStart(block);
708 Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
709 Value executionMode = LLVM::ConstantOp::create(
710 rewriter, loc, llvmI32Type,
711 rewriter.getI32IntegerAttr(
712 static_cast<uint32_t
>(executionModeAttr.getValue())));
713 SmallVector<int64_t> position{0};
714 structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
715 executionMode, position);
718 for (
unsigned i = 0, e = values.size(); i < e; ++i) {
719 auto attr = values.getValue()[i];
720 Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
721 structValue = LLVM::InsertValueOp::create(
722 rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i}));
724 LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue}));
725 rewriter.eraseOp(op);
734class GlobalVariablePattern
737 template <
typename... Args>
738 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
739 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
740 std::forward<Args>(args)...),
741 clientAPI(clientAPI) {}
744 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
745 ConversionPatternRewriter &rewriter)
const override {
748 if (op.getInitializer())
751 auto srcType = cast<spirv::PointerType>(op.getType());
752 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
754 return rewriter.notifyMatchFailure(op,
"type conversion failed");
759 auto storageClass = srcType.getStorageClass();
760 switch (storageClass) {
761 case spirv::StorageClass::Input:
762 case spirv::StorageClass::Private:
763 case spirv::StorageClass::Output:
764 case spirv::StorageClass::StorageBuffer:
765 case spirv::StorageClass::UniformConstant:
774 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
775 (storageClass == spirv::StorageClass::UniformConstant);
781 auto linkage = storageClass == spirv::StorageClass::Private
782 ? LLVM::Linkage::Private
783 : LLVM::Linkage::External;
784 StringAttr locationAttrName = op.getLocationAttrName();
785 IntegerAttr locationAttr = op.getLocationAttr();
786 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
792 newGlobalOp->setAttr(locationAttrName, locationAttr);
798 spirv::ClientAPI clientAPI;
803template <
typename SPIRVOp,
typename LLVMExtOp,
typename LLVMTruncOp>
806 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
809 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
810 ConversionPatternRewriter &rewriter)
const override {
812 Type fromType = op.getOperand().getType();
813 Type toType = op.getType();
815 auto dstType = this->getTypeConverter()->convertType(toType);
817 return rewriter.notifyMatchFailure(op,
"type conversion failed");
820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821 adaptor.getOperands());
825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826 adaptor.getOperands());
833class FunctionCallPattern
836 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840 ConversionPatternRewriter &rewriter)
const override {
841 if (callOp.getNumResults() == 0) {
842 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843 callOp,
TypeRange(), adaptor.getOperands(), callOp->getAttrs());
844 newOp.getProperties().operandSegmentSizes = {
845 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
846 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
851 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
853 return rewriter.notifyMatchFailure(callOp,
"type conversion failed");
854 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856 newOp.getProperties().operandSegmentSizes = {
857 static_cast<int32_t
>(adaptor.getOperands().size()), 0};
858 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
864template <
typename SPIRVOp, LLVM::FCmpPredicate predicate>
867 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
870 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
871 ConversionPatternRewriter &rewriter)
const override {
873 auto dstType = this->getTypeConverter()->convertType(op.getType());
875 return rewriter.notifyMatchFailure(op,
"type conversion failed");
877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878 op, dstType, predicate, op.getOperand1(), op.getOperand2());
884template <
typename SPIRVOp, LLVM::ICmpPredicate predicate>
887 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
890 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
891 ConversionPatternRewriter &rewriter)
const override {
893 auto dstType = this->getTypeConverter()->convertType(op.getType());
895 return rewriter.notifyMatchFailure(op,
"type conversion failed");
897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898 op, dstType, predicate, op.getOperand1(), op.getOperand2());
903class InverseSqrtPattern
906 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910 ConversionPatternRewriter &rewriter)
const override {
911 auto srcType = op.getType();
912 auto dstType = getTypeConverter()->convertType(srcType);
914 return rewriter.notifyMatchFailure(op,
"type conversion failed");
916 Location loc = op.getLoc();
918 Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
919 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
925template <
typename SPIRVOp>
928 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
931 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
932 ConversionPatternRewriter &rewriter)
const override {
933 if (!op.getMemoryAccess()) {
935 *this->getTypeConverter(), 0,
939 auto memoryAccess = *op.getMemoryAccess();
940 switch (memoryAccess) {
941 case spirv::MemoryAccess::Aligned:
942 case spirv::MemoryAccess::None:
943 case spirv::MemoryAccess::Nontemporal:
944 case spirv::MemoryAccess::Volatile: {
946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
950 *this->getTypeConverter(), alignment,
951 isVolatile, isNonTemporal);
961template <
typename SPIRVOp>
964 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
967 matchAndRewrite(SPIRVOp notOp,
typename SPIRVOp::Adaptor adaptor,
968 ConversionPatternRewriter &rewriter)
const override {
969 auto srcType = notOp.getType();
970 auto dstType = this->getTypeConverter()->convertType(srcType);
972 return rewriter.notifyMatchFailure(notOp,
"type conversion failed");
974 Location loc = notOp.getLoc();
977 isa<VectorType>(srcType)
978 ? LLVM::ConstantOp::create(
979 rewriter, loc, dstType,
980 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981 : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983 notOp.getOperand(), mask);
989template <
typename SPIRVOp>
992 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
995 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
996 ConversionPatternRewriter &rewriter)
const override {
997 rewriter.eraseOp(op);
1004 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008 ConversionPatternRewriter &rewriter)
const override {
1009 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1017 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter)
const override {
1022 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023 adaptor.getOperands());
1032 bool convergent =
true) {
1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1039 func = LLVM::LLVMFuncOp::create(
1040 b, symbolTable->
getLoc(), name,
1041 LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043 func.setConvergent(convergent);
1044 func.setNoUnwind(
true);
1045 func.setWillReturn(
true);
1050 LLVM::LLVMFuncOp
func,
1052 auto call = LLVM::CallOp::create(builder, loc,
func, args);
1053 call.setCConv(
func.getCConv());
1054 call.setConvergentAttr(
func.getConvergentAttr());
1055 call.setNoUnwindAttr(
func.getNoUnwindAttr());
1056 call.setWillReturnAttr(
func.getWillReturnAttr());
1060template <
typename BarrierOpTy>
1063 using OpAdaptor =
typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1065 using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1067 static constexpr StringRef getFuncName();
1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071 ConversionPatternRewriter &rewriter)
const override {
1072 constexpr StringRef funcName = getFuncName();
1073 Operation *symbolTable =
1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1076 Type i32 = rewriter.getI32Type();
1078 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079 LLVM::LLVMFuncOp func =
1082 Location loc = controlBarrierOp->getLoc();
1083 Value execution = LLVM::ConstantOp::create(
1084 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getExecutionScope()));
1085 Value memory = LLVM::ConstantOp::create(
1086 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemoryScope()));
1087 Value semantics = LLVM::ConstantOp::create(
1088 rewriter, loc, i32,
static_cast<int32_t
>(adaptor.getMemorySemantics()));
1091 {execution, memory, semantics});
1093 rewriter.replaceOp(controlBarrierOp, call);
1100StringRef getTypeMangling(
Type type,
bool isSigned) {
1102 .Case([](Float16Type) {
return "Dh"; })
1103 .Case([](Float32Type) {
return "f"; })
1104 .Case([](Float64Type) {
return "d"; })
1105 .Case([isSigned](IntegerType intTy) {
1106 switch (intTy.getWidth()) {
1110 return (isSigned) ?
"a" :
"c";
1112 return (isSigned) ?
"s" :
"t";
1114 return (isSigned) ?
"i" :
"j";
1116 return (isSigned) ?
"l" :
"m";
1118 llvm_unreachable(
"Unsupported integer width");
1121 .DefaultUnreachable(
"No mangling defined");
1124template <
typename ReduceOp>
1125constexpr StringLiteral getGroupFuncName();
1128constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1129 return "_Z17__spirv_GroupIAddii";
1132constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1133 return "_Z17__spirv_GroupFAddii";
1136constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1137 return "_Z17__spirv_GroupSMinii";
1140constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1141 return "_Z17__spirv_GroupUMinii";
1144constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1145 return "_Z17__spirv_GroupFMinii";
1148constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1149 return "_Z17__spirv_GroupSMaxii";
1152constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1153 return "_Z17__spirv_GroupUMaxii";
1156constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1157 return "_Z17__spirv_GroupFMaxii";
1160constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1161 return "_Z27__spirv_GroupNonUniformIAddii";
1164constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1165 return "_Z27__spirv_GroupNonUniformFAddii";
1168constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1169 return "_Z27__spirv_GroupNonUniformIMulii";
1172constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1173 return "_Z27__spirv_GroupNonUniformFMulii";
1176constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1177 return "_Z27__spirv_GroupNonUniformSMinii";
1180constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1181 return "_Z27__spirv_GroupNonUniformUMinii";
1184constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1185 return "_Z27__spirv_GroupNonUniformFMinii";
1188constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1189 return "_Z27__spirv_GroupNonUniformSMaxii";
1192constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1193 return "_Z27__spirv_GroupNonUniformUMaxii";
1196constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1197 return "_Z27__spirv_GroupNonUniformFMaxii";
1200constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1201 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1204constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1205 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1208constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1209 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1212constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1213 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1216constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1217 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1220constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1221 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1225template <
typename ReduceOp,
bool Signed = false,
bool NonUniform = false>
1228 using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1231 matchAndRewrite(ReduceOp op,
typename ReduceOp::Adaptor adaptor,
1232 ConversionPatternRewriter &rewriter)
const override {
1234 Type retTy = op.getResult().getType();
1238 SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1239 funcName += getTypeMangling(retTy,
false);
1241 Type i32Ty = rewriter.getI32Type();
1242 SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1243 if constexpr (NonUniform) {
1244 if (adaptor.getClusterSize()) {
1246 paramTypes.push_back(i32Ty);
1250 Operation *symbolTable =
1251 op->template getParentWithTrait<OpTrait::SymbolTable>();
1253 LLVM::LLVMFuncOp func =
1256 Location loc = op.getLoc();
1257 Value scope = LLVM::ConstantOp::create(
1258 rewriter, loc, i32Ty,
1259 static_cast<int32_t
>(adaptor.getExecutionScope()));
1260 Value groupOp = LLVM::ConstantOp::create(
1261 rewriter, loc, i32Ty,
1262 static_cast<int32_t
>(adaptor.getGroupOperation()));
1263 SmallVector<Value> operands{scope, groupOp};
1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1267 rewriter.replaceOp(op, call);
1274ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1275 return "_Z22__spirv_ControlBarrieriii";
1280ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1281 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1286ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1287 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1340 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1344 ConversionPatternRewriter &rewriter)
const override {
1346 if (loopOp.getLoopControl() != spirv::LoopControl::None)
1350 if (loopOp.getBody().empty()) {
1351 rewriter.eraseOp(loopOp);
1355 Location loc = loopOp.getLoc();
1359 Block *currentBlock = rewriter.getBlock();
1361 Block *endBlock = rewriter.splitBlock(currentBlock, position);
1365 Block *entryBlock = loopOp.getEntryBlock();
1367 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->
getOperations().front());
1370 Block *headerBlock = loopOp.getHeaderBlock();
1371 rewriter.setInsertionPointToEnd(currentBlock);
1372 LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
1373 rewriter.eraseBlock(entryBlock);
1376 Block *mergeBlock = loopOp.getMergeBlock();
1379 rewriter.setInsertionPointToEnd(mergeBlock);
1380 LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
1382 rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1393 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1397 ConversionPatternRewriter &rewriter)
const override {
1401 if (op.getSelectionControl() != spirv::SelectionControl::None)
1408 if (op.getBody().getBlocks().size() <= 2) {
1409 rewriter.eraseOp(op);
1413 Location loc = op.getLoc();
1417 auto *currentBlock = rewriter.getInsertionBlock();
1418 rewriter.setInsertionPointAfter(op);
1419 auto position = rewriter.getInsertionPoint();
1420 auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1426 auto *headerBlock = op.getHeaderBlock();
1428 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1434 auto *mergeBlock = op.getMergeBlock();
1437 rewriter.setInsertionPointToEnd(mergeBlock);
1438 LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
1441 Block *trueBlock = condBrOp.getTrueBlock();
1442 Block *falseBlock = condBrOp.getFalseBlock();
1443 rewriter.setInsertionPointToEnd(currentBlock);
1444 LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
1445 condBrOp.getTrueTargetOperands(), falseBlock,
1446 condBrOp.getFalseTargetOperands());
1448 rewriter.eraseBlock(headerBlock);
1449 rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1450 rewriter.replaceOp(op, continueBlock->getArguments());
1459template <
typename SPIRVOp,
typename LLVMOp>
1462 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1465 matchAndRewrite(SPIRVOp op,
typename SPIRVOp::Adaptor adaptor,
1466 ConversionPatternRewriter &rewriter)
const override {
1468 auto dstType = this->getTypeConverter()->convertType(op.getType());
1470 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1472 Type op1Type = op.getOperand1().getType();
1473 Type op2Type = op.getOperand2().getType();
1475 if (op1Type == op2Type) {
1476 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477 adaptor.getOperands());
1481 std::optional<uint64_t> dstTypeWidth =
1483 std::optional<uint64_t> op2TypeWidth =
1486 if (!dstTypeWidth || !op2TypeWidth)
1489 Location loc = op.getLoc();
1491 if (op2TypeWidth < dstTypeWidth) {
1494 LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1497 LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1499 }
else if (op2TypeWidth == dstTypeWidth) {
1500 extended = adaptor.getOperand2();
1506 LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
1507 rewriter.replaceOp(op,
result);
1514 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1517 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1518 ConversionPatternRewriter &rewriter)
const override {
1519 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1521 return rewriter.notifyMatchFailure(tanOp,
"type conversion failed");
1523 rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
1524 adaptor.getOperands());
1531 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1534 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1535 ConversionPatternRewriter &rewriter)
const override {
1536 auto srcType = tanhOp.getType();
1537 auto dstType = getTypeConverter()->convertType(srcType);
1539 return rewriter.notifyMatchFailure(tanhOp,
"type conversion failed");
1541 rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
1542 adaptor.getOperands());
1552 using SPIRVToLLVMConversion<spirv::GLSAbsOp>::SPIRVToLLVMConversion;
1555 matchAndRewrite(spirv::GLSAbsOp op, OpAdaptor adaptor,
1556 ConversionPatternRewriter &rewriter)
const override {
1557 Type dstType = getTypeConverter()->convertType(op.getType());
1559 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1561 rewriter.replaceOpWithNewOp<LLVM::AbsOp>(op, dstType, adaptor.getOperand(),
1569 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1572 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1573 ConversionPatternRewriter &rewriter)
const override {
1574 auto srcType = varOp.getType();
1576 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1577 auto init = varOp.getInitializer();
1578 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1581 auto dstType = getTypeConverter()->convertType(srcType);
1583 return rewriter.notifyMatchFailure(varOp,
"type conversion failed");
1585 Location loc = varOp.getLoc();
1588 auto elementType = getTypeConverter()->convertType(pointerTo);
1590 return rewriter.notifyMatchFailure(varOp,
"type conversion failed");
1591 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1595 auto elementType = getTypeConverter()->convertType(pointerTo);
1597 return rewriter.notifyMatchFailure(varOp,
"type conversion failed");
1599 LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
1600 LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
1601 rewriter.replaceOp(varOp, allocated);
1610class BitcastConversionPattern
1613 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1616 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1617 ConversionPatternRewriter &rewriter)
const override {
1618 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1620 return rewriter.notifyMatchFailure(bitcastOp,
"type conversion failed");
1623 if (isa<LLVM::LLVMPointerType>(dstType)) {
1624 rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1628 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1629 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1640 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1643 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1644 ConversionPatternRewriter &rewriter)
const override {
1648 auto funcType = funcOp.getFunctionType();
1649 TypeConverter::SignatureConversion signatureConverter(
1650 funcType.getNumInputs());
1651 auto llvmType =
static_cast<const LLVMTypeConverter *
>(getTypeConverter())
1652 ->convertFunctionSignature(
1654 false, signatureConverter);
1659 Location loc = funcOp.getLoc();
1660 StringRef name = funcOp.getName();
1661 auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
1664 MLIRContext *context = funcOp.getContext();
1665 switch (funcOp.getFunctionControl()) {
1666 case spirv::FunctionControl::Inline:
1667 newFuncOp.setAlwaysInline(
true);
1669 case spirv::FunctionControl::DontInline:
1670 newFuncOp.setNoInline(
true);
1673#define DISPATCH(functionControl, llvmAttr) \
1674 case functionControl: \
1675 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1678 DISPATCH(spirv::FunctionControl::Pure,
1679 StringAttr::get(context,
"readonly"));
1680 DISPATCH(spirv::FunctionControl::Const,
1681 StringAttr::get(context,
"readnone"));
1691 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1693 if (
failed(rewriter.convertRegionTypes(
1694 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1697 rewriter.eraseOp(funcOp);
1708 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1711 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1712 ConversionPatternRewriter &rewriter)
const override {
1715 ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
1716 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1719 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1720 rewriter.eraseOp(spvModuleOp);
1729class VectorShufflePattern
1732 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1734 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1735 ConversionPatternRewriter &rewriter)
const override {
1736 Location loc = op.getLoc();
1737 auto components = adaptor.getComponents();
1738 auto vector1 = adaptor.getVector1();
1739 auto vector2 = adaptor.getVector2();
1740 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1741 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1742 if (vector1Size == vector2Size) {
1743 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1744 op, vector1, vector2,
1749 auto dstType = getTypeConverter()->convertType(op.getType());
1751 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1752 auto scalarType = cast<VectorType>(dstType).getElementType();
1753 auto componentsArray = components.getValue();
1754 auto *context = rewriter.getContext();
1755 auto llvmI32Type = IntegerType::get(context, 32);
1756 Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
1757 for (
unsigned i = 0; i < componentsArray.size(); i++) {
1758 if (!isa<IntegerAttr>(componentsArray[i]))
1759 return op.emitError(
"unable to support non-constant component");
1761 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1766 Value baseVector = vector1;
1767 if (indexVal >= vector1Size) {
1768 offsetVal = vector1Size;
1769 baseVector = vector2;
1772 Value dstIndex = LLVM::ConstantOp::create(
1773 rewriter, loc, llvmI32Type,
1774 rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1775 Value index = LLVM::ConstantOp::create(
1776 rewriter, loc, llvmI32Type,
1777 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1779 auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
1781 targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
1782 extractOp, dstIndex);
1784 rewriter.replaceOp(op, targetOp);
1795 spirv::ClientAPI clientAPI) {
1812 spirv::ClientAPI clientAPI) {
1815 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1816 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1817 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1818 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1819 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1820 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1821 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1822 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1823 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1824 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1825 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1826 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1827 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1830 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1831 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1832 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1833 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1834 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1835 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1836 NotPattern<spirv::NotOp>,
1839 BitcastConversionPattern,
1840 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1841 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1842 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1843 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1844 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1845 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1846 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1849 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1850 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1851 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1852 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1853 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1854 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1855 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1856 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1857 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1858 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1859 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1860 LLVM::FCmpPredicate::uge>,
1861 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1862 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1863 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1864 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1865 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1866 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1867 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1868 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1869 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1870 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1871 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1874 ConstantScalarAndVectorPattern,
1877 BranchConversionPattern, BranchConditionalConversionPattern,
1878 FunctionCallPattern, LoopPattern, SelectionPattern,
1879 ErasePattern<spirv::MergeOp>,
1882 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1885 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1886 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1887 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1888 DirectConversionPattern<spirv::GLExp2Op, LLVM::Exp2Op>,
1889 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1890 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1891 DirectConversionPattern<spirv::GLFmaOp, LLVM::FMAOp>,
1892 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1893 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1894 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1895 DirectConversionPattern<spirv::GLLog2Op, LLVM::Log2Op>,
1896 DirectConversionPattern<spirv::GLPowOp, LLVM::PowOp>,
1897 DirectConversionPattern<spirv::GLRoundOp, LLVM::RoundOp>,
1898 DirectConversionPattern<spirv::GLRoundEvenOp, LLVM::RoundEvenOp>,
1899 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1900 DirectConversionPattern<spirv::GLSinhOp, LLVM::SinhOp>,
1901 DirectConversionPattern<spirv::GLCoshOp, LLVM::CoshOp>,
1902 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1903 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1904 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1905 DirectConversionPattern<spirv::GLUMaxOp, LLVM::UMaxOp>,
1906 DirectConversionPattern<spirv::GLUMinOp, LLVM::UMinOp>,
1907 InverseSqrtPattern, SAbsPattern, TanPattern, TanhPattern,
1910 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1911 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1912 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1913 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1914 NotPattern<spirv::LogicalNotOp>,
1917 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1918 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1921 CompositeExtractPattern, CompositeInsertPattern,
1922 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1923 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1924 VectorShufflePattern,
1927 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1928 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1929 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1932 ReturnPattern, ReturnValuePattern,
1935 ControlBarrierPattern<spirv::ControlBarrierOp>,
1936 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1937 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1940 GroupReducePattern<spirv::GroupIAddOp>,
1941 GroupReducePattern<spirv::GroupFAddOp>,
1942 GroupReducePattern<spirv::GroupFMinOp>,
1943 GroupReducePattern<spirv::GroupUMinOp>,
1944 GroupReducePattern<spirv::GroupSMinOp,
true>,
1945 GroupReducePattern<spirv::GroupFMaxOp>,
1946 GroupReducePattern<spirv::GroupUMaxOp>,
1947 GroupReducePattern<spirv::GroupSMaxOp,
true>,
1948 GroupReducePattern<spirv::GroupNonUniformIAddOp,
false,
1950 GroupReducePattern<spirv::GroupNonUniformFAddOp,
false,
1952 GroupReducePattern<spirv::GroupNonUniformIMulOp,
false,
1954 GroupReducePattern<spirv::GroupNonUniformFMulOp,
false,
1956 GroupReducePattern<spirv::GroupNonUniformSMinOp,
true,
1958 GroupReducePattern<spirv::GroupNonUniformUMinOp,
false,
1960 GroupReducePattern<spirv::GroupNonUniformFMinOp,
false,
1962 GroupReducePattern<spirv::GroupNonUniformSMaxOp,
true,
1964 GroupReducePattern<spirv::GroupNonUniformUMaxOp,
false,
1966 GroupReducePattern<spirv::GroupNonUniformFMaxOp,
false,
1968 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp,
false,
1970 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp,
false,
1972 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp,
false,
1974 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp,
false,
1976 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp,
false,
1978 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp,
false,
1982 patterns.
add<GlobalVariablePattern>(clientAPI, patterns.
getContext(),
1988 patterns.
add<FuncConversionPattern>(patterns.
getContext(), typeConverter);
1993 patterns.
add<ModuleConversionPattern>(patterns.
getContext(), typeConverter);
2004 auto spvModules =
module.getOps<spirv::ModuleOp>();
2005 for (
auto spvModule : spvModules) {
2006 spvModule.walk([&](spirv::GlobalVariableOp op) {
2007 IntegerAttr descriptorSet =
2009 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(
kBinding);
2012 if (descriptorSet && binding) {
2015 auto moduleAndName =
2016 spvModule.getName().has_value()
2017 ? spvModule.getName()->str() +
"_" + op.getSymName().str()
2018 : op.getSymName().str();
2020 llvm::formatv(
"{0}_descriptor_set{1}_binding{2}", moduleAndName,
2021 std::to_string(descriptorSet.getInt()),
2022 std::to_string(binding.getInt()));
2023 auto nameAttr = StringAttr::get(op->getContext(), name);
2028 op.emitError(
"unable to replace all symbol uses for ") << name;
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
#define DISPATCH(functionControl, llvmAttr)
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
OpListType::iterator iterator
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.