24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
// Variant of VectorConvertToLLVMPattern that only applies when the presence of
// a rounding-mode attribute on the source op agrees with `HasRoundingMode`.
//
// Template parameters:
//   SourceOp            - arith op being lowered.
//   TargetOp            - LLVM dialect op (or constrained intrinsic op) emitted.
//   HasRoundingMode     - true: only ops carrying a rounding-mode attribute
//                         match; false: only ops without one match.
//   AttrConvert         - attribute-conversion template forwarded to the base
//                         pattern (its default argument is not visible in this
//                         excerpt).
//   FailOnUnsupportedFP - forwarded unchanged to the base pattern; defaults to
//                         false.
//
// This lets a plain lowering (e.g. to LLVM::FAddOp) and a constrained-intrinsic
// lowering (e.g. to LLVM::ConstrainedFAddIntr, instantiated with
// HasRoundingMode = true) be registered for the same arith op. The plain
// variants are presumably instantiated with false; their third template
// argument is not visible here.
45template <
typename SourceOp,
typename TargetOp,
bool HasRoundingMode,
46 template <
typename,
typename>
typename AttrConvert =
48 bool FailOnUnsupportedFP =
false>
49struct ConstrainedVectorConvertToLLVMPattern
51 FailOnUnsupportedFP> {
52 using VectorConvertToLLVMPattern<
53 SourceOp, TargetOp, AttrConvert,
54 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
57 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
58 ConversionPatternRewriter &rewriter)
const override {
// Reject ops whose rounding-mode attribute presence disagrees with
// HasRoundingMode. Matching ops are handed to the base pattern's
// matchAndRewrite.
// NOTE(review): the early-exit statement for this check (original line 60)
// is missing from this excerpt. It is presumably `return failure();`; confirm.
59 if (HasRoundingMode !=
static_cast<bool>(op.getRoundingModeAttr()))
61 return VectorConvertToLLVMPattern<
62 SourceOp, TargetOp, AttrConvert,
63 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
// Folds away an arith.bitcast whose converted operand type already equals the
// converted result type: the op is replaced by its (converted) operand instead
// of emitting an LLVM bitcast.
// When the types differ, this pattern reports a match failure and leaves the op
// for another bitcast lowering (presumably BitcastOpLowering).
69struct IdentityBitcastLowering final
70 :
public OpConversionPattern<arith::BitcastOp> {
74 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter)
const final {
76 Value src = adaptor.getIn();
77 Type resultType = getTypeConverter()->convertType(op.getType());
78 if (src.
getType() != resultType)
79 return rewriter.notifyMatchFailure(op,
"Types are different");
// Identity bitcast after type conversion: forward the operand unchanged.
81 rewriter.replaceOp(op, src);
// One-to-one lowering aliases: each name maps an arith op to a single LLVM
// dialect op or intrinsic.
// NOTE(review): the right-hand sides of most aliases below are missing from
// this excerpt.
91 ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
// The Constrained* aliases pass HasRoundingMode = true and target the LLVM
// constrained floating-point intrinsics. They only match ops that carry a
// rounding-mode attribute.
95using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
96 arith::AddFOp, LLVM::ConstrainedFAddIntr,
true,
102using BitcastOpLowering =
104using DivFOpLowering =
105 ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
109using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
110 arith::DivFOp, LLVM::ConstrainedFDivIntr,
true,
112using DivSIOpLowering =
114using DivUIOpLowering =
119using ExtSIOpLowering =
121using ExtUIOpLowering =
124using FPToSIOpLowering =
128using FPToUIOpLowering =
132using MaximumFOpLowering =
136using MaxNumFOpLowering =
140using MaxSIOpLowering =
142using MaxUIOpLowering =
144using MinimumFOpLowering =
148using MinNumFOpLowering =
152using MinSIOpLowering =
154using MinUIOpLowering =
156using MulFOpLowering =
157 ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
161using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
162 arith::MulFOp, LLVM::ConstrainedFMulIntr,
true,
164using MulIOpLowering =
167using NegFOpLowering =
172using RemFOpLowering =
176using RemSIOpLowering =
178using RemUIOpLowering =
180using SelectOpLowering =
182using ShLIOpLowering =
185using ShRSIOpLowering =
187using ShRUIOpLowering =
189using SIToFPOpLowering =
191using SubFOpLowering =
192 ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
196using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
197 arith::SubFOp, LLVM::ConstrainedFSubIntr,
true,
199using SubIOpLowering =
202using TruncFOpLowering =
203 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
207using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
208 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
210using TruncIOpLowering =
213using UIToFPOpLowering =
// Declaration of ConstantOpLowering::matchAndRewrite, which lowers
// arith.constant. The out-of-line definition appears later in this file.
228 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter)
const override;
// Pattern template for integer <-> index casts. Depending on the converted bit
// widths, the out-of-line matchAndRewrite does one of three things:
//   - larger target width: widens with ExtCastTy;
//   - smaller target width: truncates;
//   - equal widths: forwards the operand unchanged.
236template <
typename OpTy,
typename ExtCastTy>
238 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
241 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
242 ConversionPatternRewriter &rewriter)
const override;
// arith.index_cast is signed, so widening uses sign extension.
245using IndexCastOpSILowering =
246 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
// arith.index_castui is unsigned, so widening uses zero extension.
247using IndexCastOpUILowering =
248 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
// Lowers arith::AddUIExtendedOp to LLVM's unsigned add-with-overflow op,
// producing the sum and the overflow bit. Defined out of line later in this
// file.
250struct AddUIExtendedOpLowering
255 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter)
const override;
// Lowers arith::SubUIExtendedOp to LLVM's unsigned sub-with-overflow op,
// producing the difference and the borrow bit. Defined out of line later in
// this file.
259struct SubUIExtendedOpLowering
264 matchAndRewrite(arith::SubUIExtendedOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter)
const override;
// Lowers the extended integer multiplications (full-width product split into
// low and high halves). IsSigned selects sign extension (true) or zero
// extension (false) of the operands before the double-width multiply.
268template <
typename ArithMulOp,
bool IsSigned>
270 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
273 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
274 ConversionPatternRewriter &rewriter)
const override;
277using MulSIExtendedOpLowering =
278 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
279using MulUIExtendedOpLowering =
280 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
// Declaration: lowers arith.cmpi to LLVM::ICmpOp. Multi-dimensional vectors are
// handled per 1-D slice (see the out-of-line definition).
286 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter)
const override;
// Declaration: lowers arith.cmpf to LLVM::FCmpOp, carrying over the op's
// fastmath flags (see the out-of-line definition).
294 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter)
const override;
// Lowers arith::ConvertFOp. Only bf16 <-> f16 conversions are supported, and
// that is asserted below; the actual IR is produced by emitConversion.
// - Scalar and 1-D vector operands (anything not lowered to an LLVM array) are
//   converted directly.
// - Operands lowered to LLVM arrays (multi-dimensional vectors) are converted
//   one 1-D slice at a time through the lambda below.
// NOTE(review): several original lines are missing from this excerpt. These
// include the start of the unsupported-FP check, the srcType/dstType
// declarations, and the null check on the converted result type.
307 matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter)
const override {
// Bail out on floating-point types the type converter cannot lower.
310 *getTypeConverter()))
311 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
// Parenthesize the entire disjunction so the message string guards the whole
// condition, not just the second alternative. This also silences the
// "'&&' within '||'" (-Wparentheses) warning.
317 assert(((srcType.isBF16() && dstType.isF16()) ||
318 (srcType.isF16() && dstType.isBF16())) &&
319 "only bf16 <-> f16 conversions are supported");
321 Type convertedType = getTypeConverter()->convertType(op.getType());
323 return rewriter.notifyMatchFailure(op,
"failed to convert result type");
325 Value input = adaptor.getIn();
326 Location loc = op.getLoc();
// Scalar / 1-D vector case: emit the conversion in one shot.
328 if (!isa<LLVM::LLVMArrayType>(input.
getType())) {
329 rewriter.replaceOp(op,
330 emitConversion(rewriter, loc, input, convertedType));
// Array-typed operand: only legal when the original result is a vector.
334 if (!isa<VectorType>(op.getType()))
335 return rewriter.notifyMatchFailure(op,
"expected vector result type");
338 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
339 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
340 return emitConversion(rewriter, loc, operands.front(),
// Emits a bf16 <-> f16 conversion as `fpext` to f32 followed by `fptrunc` to
// `targetType`. For a vector target, the intermediate type is an f32 vector of
// the same shape.
// Why via f32: bf16 and f16 have the same bit width, so neither fpext nor
// fptrunc can convert between them directly. f32 represents every bf16 and
// f16 value exactly, so the extension is lossless and only the final
// truncation rounds.
347 static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
348 Value input, Type targetType) {
349 Type f32Scalar = Float32Type::get(rewriter.getContext());
350 Type f32Ty = f32Scalar;
351 if (
auto vecTy = dyn_cast<VectorType>(targetType))
352 f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
354 Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
355 return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
// Declaration: handles arith.select whose true/false values were converted 1:N,
// emitting one select per converted value. Plain 1:1 cases are declined so the
// regular select lowering handles them (see the out-of-line definition).
364 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
365 ConversionPatternRewriter &rewriter)
const override;
// Lowers arith.constant by delegating to a generic one-to-one rewrite helper.
// It passes the op's operands, its attributes, the type converter, and the
// rewriter.
// NOTE(review): the callee name and target op (original line 377) and one
// argument (original line 379) are missing from this excerpt. The callee is
// presumably LLVM::detail::oneToOneRewrite targeting LLVM::ConstantOp; confirm.
375ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
376 ConversionPatternRewriter &rewriter)
const {
378 adaptor.getOperands(), op->getAttrs(),
380 *getTypeConverter(), rewriter);
// Lowers arith.index_cast / arith.index_castui (see the IndexCastOpSILowering /
// IndexCastOpUILowering aliases). Depending on the converted bit widths, it
// emits an LLVM truncation, an extension of type ExtCastTy, or nothing.
// - Equal source/target widths: the op is replaced by its converted operand.
// - memref operands: the op is replaced by its converted operand.
// - Scalar / 1-D vector operands: a single trunc or ExtCastTy op is emitted.
// - Operands lowered to LLVM arrays (multi-dimensional vectors): the cast is
//   emitted per 1-D slice by the lambda below.
// When ExtCastTy is LLVM::ZExtOp, the op's nonNeg flag is propagated to the
// emitted zext in BOTH the scalar and the per-slice paths.
// NOTE(review): several original lines are missing from this excerpt, e.g. the
// bit-width computations and the `return success();` statements.
387template <
typename OpTy,
typename ExtCastTy>
388LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
389 OpTy op,
typename OpTy::Adaptor adaptor,
390 ConversionPatternRewriter &rewriter)
const {
391 Type resultType = op.getResult().getType();
392 Type targetElementType =
394 Type sourceElementType =
// Same width after conversion: the cast is a no-op.
399 if (targetBits == sourceBits) {
400 rewriter.replaceOp(op, adaptor.getIn());
// memref operands are forwarded unchanged.
407 if (isa<MemRefType>(op.getIn().getType())) {
408 rewriter.replaceOp(op, adaptor.getIn());
// Only the zext instantiation (index_castui) carries a nonNeg flag.
412 bool isNonNeg =
false;
413 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
414 isNonNeg = op.getNonNeg();
// Scalar and 1-D vector case.
417 Type operandType = adaptor.getIn().getType();
418 if (!isa<LLVM::LLVMArrayType>(operandType)) {
419 Type targetType = this->typeConverter->convertType(resultType);
420 if (targetBits < sourceBits) {
421 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
424 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
426 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
427 extOp.setNonNeg(isNonNeg);
// Array-typed operand: only legal when the original result is a vector.
432 if (!isa<VectorType>(resultType))
433 return rewriter.notifyMatchFailure(op,
"expected vector result type");
436 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
437 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
438 typename OpTy::Adaptor adaptor(operands);
439 if (targetBits < sourceBits) {
440 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
443 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
// Propagate the op's nonNeg flag exactly as the scalar path does.
// Setting it unconditionally would mark every zext as `nneg`, even
// for index_castui ops that do not guarantee a non-negative operand.
// LLVM defines `zext nneg` of a negative value as poison.
445 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
447 extOp.setNonNeg(isNonNeg);
// Lowers arith::AddUIExtendedOp to LLVM::UAddWithOverflowOp. That op returns a
// literal struct {sum, overflow}; field 0 becomes the op's sum result and
// field 1 its overflow result. Only scalar and 1-D vector operands are
// supported. Operands lowered to LLVM arrays (N-D vectors) are rejected with a
// match failure.
// NOTE(review): the consequence of the isCompatibleType check (original line
// 466) and some declarations are missing from this excerpt.
458LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
459 arith::AddUIExtendedOp op, OpAdaptor adaptor,
460 ConversionPatternRewriter &rewriter)
const {
461 Type operandType = adaptor.getLhs().getType();
462 Type sumResultType = op.getSum().getType();
463 Type overflowResultType = op.getOverflow().getType();
465 if (!LLVM::isCompatibleType(operandType))
468 MLIRContext *ctx = rewriter.getContext();
469 Location loc = op.getLoc();
// Scalar / 1-D vector case: a single overflow intrinsic plus two extracts.
472 if (!isa<LLVM::LLVMArrayType>(operandType)) {
473 Type newOverflowType = typeConverter->convertType(overflowResultType);
475 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
476 Value addOverflow = LLVM::UAddWithOverflowOp::create(
477 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
479 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
480 Value overflowExtracted =
481 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
482 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
486 if (!isa<VectorType>(sumResultType))
487 return rewriter.notifyMatchFailure(loc,
"expected vector result types");
489 return rewriter.notifyMatchFailure(loc,
490 "ND vector types are not supported yet");
// Lowers arith::SubUIExtendedOp to LLVM::USubWithOverflowOp. That op returns a
// literal struct {diff, borrow}; field 0 becomes the op's diff result and
// field 1 its borrow result. Only scalar and 1-D vector operands are
// supported. Operands lowered to LLVM arrays (N-D vectors) are rejected with a
// match failure.
// NOTE(review): the consequence of the isCompatibleType check (original line
// 505) and the structType declaration line are missing from this excerpt.
497LogicalResult SubUIExtendedOpLowering::matchAndRewrite(
498 arith::SubUIExtendedOp op, OpAdaptor adaptor,
499 ConversionPatternRewriter &rewriter)
const {
500 Type operandType = adaptor.getLhs().getType();
501 Type diffResultType = op.getDiff().getType();
502 Type borrowResultType = op.getBorrow().getType();
504 if (!LLVM::isCompatibleType(operandType))
507 MLIRContext *ctx = rewriter.getContext();
508 Location loc = op.getLoc();
// Scalar / 1-D vector case: a single overflow intrinsic plus two extracts.
511 if (!isa<LLVM::LLVMArrayType>(operandType)) {
512 Type newBorrowType = typeConverter->convertType(borrowResultType);
514 LLVM::LLVMStructType::getLiteral(ctx, {diffResultType, newBorrowType});
515 Value subOverflow = LLVM::USubWithOverflowOp::create(
516 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
517 Value diffExtracted =
518 LLVM::ExtractValueOp::create(rewriter, loc, subOverflow, 0);
519 Value borrowExtracted =
520 LLVM::ExtractValueOp::create(rewriter, loc, subOverflow, 1);
521 rewriter.replaceOp(op, {diffExtracted, borrowExtracted});
525 if (!isa<VectorType>(diffResultType))
526 return rewriter.notifyMatchFailure(loc,
"expected vector result types");
528 return rewriter.notifyMatchFailure(loc,
529 "ND vector types are not supported yet");
// Lowers an extended integer multiply. Both operands are extended to twice
// their bit width, multiplied, and the 2N-bit product is split into its low
// half (trunc) and high half (shift right by N, then trunc). The 2N-bit
// product of two extended N-bit values cannot overflow 2N bits, so it is
// exact. Only scalar and 1-D vector operands are supported; operands lowered
// to LLVM arrays (N-D vectors) are rejected with a match failure.
// NOTE(review): the consequence of the isCompatibleType check and the `else`
// line before the vector branch are missing from this excerpt.
536template <
typename ArithMulOp,
bool IsSigned>
537LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
538 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter)
const {
540 Type resultType = adaptor.getLhs().getType();
542 if (!LLVM::isCompatibleType(resultType))
545 Location loc = op.getLoc();
551 if (!isa<LLVM::LLVMArrayType>(resultType)) {
// Build the shift amount N as a constant of the 2N-bit wide type. It is a
// scalar for integer results and a splat for vector results; its type also
// serves as the wide type for the multiply.
553 TypedAttr shiftValAttr;
555 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
556 unsigned resultBitwidth = intTy.getWidth();
557 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
558 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
560 auto vecTy = cast<VectorType>(resultType);
561 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
562 auto attrTy = VectorType::get(
563 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
564 shiftValAttr = SplatElementsAttr::get(
565 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
567 Type wideType = shiftValAttr.getType();
568 assert(LLVM::isCompatibleType(wideType) &&
569 "LLVM dialect should support all signless integer types");
// Signed variant sign-extends, unsigned variant zero-extends.
571 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
572 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
573 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
574 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
// A logical shift is correct even for the signed variant. The shift is by
// exactly N, so the bits it fills in land in the upper N bits, and the
// following trunc discards those bits.
577 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
578 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
579 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
580 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
582 rewriter.replaceOp(op, {low, high});
586 if (!isa<VectorType>(resultType))
587 return rewriter.notifyMatchFailure(op,
"expected vector result type");
589 return rewriter.notifyMatchFailure(op,
590 "ND vector types are not supported yet");
// Converts an arith comparison predicate to the corresponding LLVM predicate
// with a plain static_cast. This relies on the two enums having matching
// underlying values.
// NOTE(review): confirm that a compile-time check of that correspondence
// exists elsewhere; none is visible in this excerpt.
599template <
typename LLVMPredType,
typename PredType>
601 return static_cast<LLVMPredType
>(pred);
// Lowers arith.cmpi to LLVM::ICmpOp. Scalar and 1-D vector operands are
// replaced directly. Operands lowered to LLVM arrays (multi-dimensional
// vectors) are compared per 1-D slice through the lambda below.
// NOTE(review): the predicate argument lines and the name of the per-slice
// helper call are missing from this excerpt.
605CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
606 ConversionPatternRewriter &rewriter)
const {
607 Type operandType = adaptor.getLhs().getType();
608 Type resultType = op.getResult().getType();
611 if (!isa<LLVM::LLVMArrayType>(operandType)) {
612 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
613 op, typeConverter->convertType(resultType),
615 adaptor.getLhs(), adaptor.getRhs());
619 if (!isa<VectorType>(resultType))
620 return rewriter.notifyMatchFailure(op,
"expected vector result type");
623 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
624 [&](Type llvm1DVectorTy,
ValueRange operands) {
// Re-wrap the slice operands so they can be read as lhs/rhs by name.
625 OpAdaptor adaptor(operands);
626 return LLVM::ICmpOp::create(
627 rewriter, op.getLoc(), llvm1DVectorTy,
629 adaptor.getLhs(), adaptor.getRhs());
// Lowers arith.cmpf to LLVM::FCmpOp, translating the op's arith fastmath flags
// to LLVM fastmath flags. Scalar and 1-D vector operands are replaced directly.
// Operands lowered to LLVM arrays (multi-dimensional vectors) are compared per
// 1-D slice. Ops on unsupported floating-point types are rejected up front.
// NOTE(review): the first line of the unsupported-type check and the
// predicate argument lines are missing from this excerpt.
639CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
640 ConversionPatternRewriter &rewriter)
const {
642 op.getLhs().getType()))
643 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
645 Type operandType = adaptor.getLhs().getType();
646 Type resultType = op.getResult().getType();
647 LLVM::FastmathFlags fmf =
648 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
651 if (!isa<LLVM::LLVMArrayType>(operandType)) {
652 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
653 op, typeConverter->convertType(resultType),
655 adaptor.getLhs(), adaptor.getRhs(), fmf);
659 if (!isa<VectorType>(resultType))
660 return rewriter.notifyMatchFailure(op,
"expected vector result type");
663 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
664 [&](Type llvm1DVectorTy,
ValueRange operands) {
// Re-wrap the slice operands so they can be read as lhs/rhs by name.
665 OpAdaptor adaptor(operands);
666 return LLVM::FCmpOp::create(
667 rewriter, op.getLoc(), llvm1DVectorTy,
669 adaptor.getLhs(), adaptor.getRhs(), fmf);
// Handles arith.select whose true/false values were converted 1:N.
// - It emits one new arith.select per pair of converted values, all sharing
//   the original i1 condition.
// - It replaces the op with the resulting values.
// - It declines single-value (1:1) cases so the regular select lowering
//   matches them.
// - It declines non-i1 (e.g. vector) conditions.
// The emitted arith.select ops are presumably lowered again by the 1:1 select
// pattern.
681LogicalResult SelectOpOneToNLowering::matchAndRewrite(
682 arith::SelectOp op, Adaptor adaptor,
683 ConversionPatternRewriter &rewriter)
const {
685 if (llvm::hasSingleElement(adaptor.getTrueValue()))
686 return rewriter.notifyMatchFailure(
687 op,
"not a 1:N conversion, 1:1 pattern will match");
688 if (!op.getCondition().getType().isInteger(1))
689 return rewriter.notifyMatchFailure(op,
690 "non-i1 conditions are not supported");
691 SmallVector<Value> results;
// zip_equal: true and false values must have been split into the same number
// of pieces.
692 for (
auto [trueValue, falseValue] :
693 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
694 results.push_back(arith::SelectOp::create(
695 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
696 rewriter.replaceOpWithMultiple(op, {results});
// Standalone arith -> LLVM conversion pass. It:
// 1. applies the pass's indexBitwidth option to the lowering options;
// 2. expands ceil/floor division ops;
// 3. adds the arith -> LLVM patterns;
// 4. runs a partial conversion on the current operation.
// NOTE(review): the guard around overrideIndexBitwidth (presumably "not
// kDeriveIndexBitwidthFromDataLayout"), the converter/target setup, and the
// failure handling are missing from this excerpt.
705struct ArithToLLVMConversionPass
706 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
709 void runOnOperation()
override {
715 options.overrideIndexBitwidth(indexBitwidth);
718 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
719 arith::populateArithToLLVMConversionPatterns(converter, patterns);
721 if (
failed(applyPartialConversion(getOperation(),
target,
722 std::move(patterns))))
// ConvertToLLVMPatternInterface implementation for the arith dialect. It lets
// a generic convert-to-LLVM driver lower arith ops. It declares the LLVM
// dialect as a dependent dialect and contributes the same two pattern sets the
// standalone pass uses: ceil/floor-div expansion plus the arith -> LLVM
// patterns.
734struct ArithToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
735 ArithToLLVMDialectInterface(Dialect *dialect)
736 : ConvertToLLVMPatternInterface(dialect) {}
738 void loadDependentDialects(MLIRContext *context)
const final {
739 context->loadDialect<LLVM::LLVMDialect>();
// The ConversionTarget parameter is not used in the visible lines.
744 void populateConvertToLLVMConversionPatterns(
745 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
746 RewritePatternSet &patterns)
const final {
747 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
748 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
// Attach the convert-to-LLVM interface to the arith dialect. The surrounding
// registry-extension code is not visible in this excerpt.
756 dialect->addInterfaces<ArithToLLVMDialectInterface>();
// Register the arith -> LLVM lowering patterns. IdentityBitcastLowering is
// added in a separate `add` call with an extra argument (original line 770,
// not visible). That argument is presumably a higher PatternBenefit so the
// identity fold is tried before the general bitcast lowering; confirm. The
// remaining patterns are added in one batch; many entries of that list are
// missing from this excerpt.
769 patterns.
add<IdentityBitcastLowering>(converter, patterns.
getContext(),
775 ConstrainedAddFOpLowering,
778 AddUIExtendedOpLowering,
779 SubUIExtendedOpLowering,
785 ConstrainedDivFOpLowering,
794 IndexCastOpSILowering,
795 IndexCastOpUILowering,
805 ConstrainedMulFOpLowering,
807 MulSIExtendedOpLowering,
808 MulUIExtendedOpLowering,
815 SelectOpOneToNLowering,
821 ConstrainedSubFOpLowering,
824 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.