24#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25#include "mlir/Conversion/Passes.h.inc"
45template <
typename SourceOp,
typename TargetOp,
bool HasRoundingMode,
46 template <
typename,
typename>
typename AttrConvert =
48 bool FailOnUnsupportedFP =
false>
49struct ConstrainedVectorConvertToLLVMPattern
51 FailOnUnsupportedFP> {
52 using VectorConvertToLLVMPattern<
53 SourceOp, TargetOp, AttrConvert,
54 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
57 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
58 ConversionPatternRewriter &rewriter)
const override {
59 if (HasRoundingMode !=
static_cast<bool>(op.getRoundingModeAttr()))
61 return VectorConvertToLLVMPattern<
62 SourceOp, TargetOp, AttrConvert,
63 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
69struct IdentityBitcastLowering final
70 :
public OpConversionPattern<arith::BitcastOp> {
74 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
75 ConversionPatternRewriter &rewriter)
const final {
76 Value src = adaptor.getIn();
77 Type resultType = getTypeConverter()->convertType(op.getType());
78 if (src.
getType() != resultType)
79 return rewriter.notifyMatchFailure(op,
"Types are different");
81 rewriter.replaceOp(op, src);
91 ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
95using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
96 arith::AddFOp, LLVM::ConstrainedFAddIntr,
true,
102using BitcastOpLowering =
104using DivFOpLowering =
105 ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
109using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
110 arith::DivFOp, LLVM::ConstrainedFDivIntr,
true,
112using DivSIOpLowering =
114using DivUIOpLowering =
119using ExtSIOpLowering =
121using ExtUIOpLowering =
124using FPToSIOpLowering =
128using FPToUIOpLowering =
132using MaximumFOpLowering =
136using MaxNumFOpLowering =
140using MaxSIOpLowering =
142using MaxUIOpLowering =
144using MinimumFOpLowering =
148using MinNumFOpLowering =
152using MinSIOpLowering =
154using MinUIOpLowering =
156using MulFOpLowering =
157 ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
161using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
162 arith::MulFOp, LLVM::ConstrainedFMulIntr,
true,
164using MulIOpLowering =
167using NegFOpLowering =
172using RemFOpLowering =
176using RemSIOpLowering =
178using RemUIOpLowering =
180using SelectOpLowering =
182using ShLIOpLowering =
185using ShRSIOpLowering =
187using ShRUIOpLowering =
189using SIToFPOpLowering =
191using SubFOpLowering =
192 ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
196using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
197 arith::SubFOp, LLVM::ConstrainedFSubIntr,
true,
199using SubIOpLowering =
202using TruncFOpLowering =
203 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
207using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
208 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
210using TruncIOpLowering =
213using UIToFPOpLowering =
228 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter)
const override;
236template <
typename OpTy,
typename ExtCastTy>
238 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
241 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
242 ConversionPatternRewriter &rewriter)
const override;
245using IndexCastOpSILowering =
246 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
247using IndexCastOpUILowering =
248 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
250struct AddUIExtendedOpLowering
255 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter)
const override;
259template <
typename ArithMulOp,
bool IsSigned>
261 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
264 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
265 ConversionPatternRewriter &rewriter)
const override;
268using MulSIExtendedOpLowering =
269 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
270using MulUIExtendedOpLowering =
271 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
277 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter)
const override;
285 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter)
const override;
298 matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter)
const override {
301 *getTypeConverter()))
302 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
308 assert((srcType.isBF16() && dstType.isF16()) ||
309 (srcType.isF16() && dstType.isBF16()) &&
310 "only bf16 <-> f16 conversions are supported");
312 Type convertedType = getTypeConverter()->convertType(op.getType());
314 return rewriter.notifyMatchFailure(op,
"failed to convert result type");
316 Value input = adaptor.getIn();
317 Location loc = op.getLoc();
319 if (!isa<LLVM::LLVMArrayType>(input.
getType())) {
320 rewriter.replaceOp(op,
321 emitConversion(rewriter, loc, input, convertedType));
325 if (!isa<VectorType>(op.getType()))
326 return rewriter.notifyMatchFailure(op,
"expected vector result type");
329 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
330 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
331 return emitConversion(rewriter, loc, operands.front(),
338 static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
339 Value input, Type targetType) {
340 Type f32Scalar = Float32Type::get(rewriter.getContext());
341 Type f32Ty = f32Scalar;
342 if (
auto vecTy = dyn_cast<VectorType>(targetType))
343 f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
345 Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
346 return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
355 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
356 ConversionPatternRewriter &rewriter)
const override;
366ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
367 ConversionPatternRewriter &rewriter)
const {
369 adaptor.getOperands(), op->getAttrs(),
371 *getTypeConverter(), rewriter);
378template <
typename OpTy,
typename ExtCastTy>
379LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
380 OpTy op,
typename OpTy::Adaptor adaptor,
381 ConversionPatternRewriter &rewriter)
const {
382 Type resultType = op.getResult().getType();
383 Type targetElementType =
390 if (targetBits == sourceBits) {
391 rewriter.replaceOp(op, adaptor.getIn());
398 if (isa<MemRefType>(op.getIn().getType())) {
399 rewriter.replaceOp(op, adaptor.getIn());
403 bool isNonNeg =
false;
404 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
405 isNonNeg = op.getNonNeg();
408 Type operandType = adaptor.getIn().getType();
409 if (!isa<LLVM::LLVMArrayType>(operandType)) {
410 Type targetType = this->typeConverter->convertType(resultType);
411 if (targetBits < sourceBits) {
412 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
415 auto extOp = rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType,
417 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
418 extOp.setNonNeg(isNonNeg);
423 if (!isa<VectorType>(resultType))
424 return rewriter.notifyMatchFailure(op,
"expected vector result type");
427 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
429 typename OpTy::Adaptor adaptor(operands);
430 if (targetBits < sourceBits) {
431 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
434 auto extOp = ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
436 if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>) {
438 extOp.setNonNeg(
true);
449LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
450 arith::AddUIExtendedOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter)
const {
452 Type operandType = adaptor.getLhs().getType();
453 Type sumResultType = op.getSum().getType();
454 Type overflowResultType = op.getOverflow().getType();
459 MLIRContext *ctx = rewriter.getContext();
460 Location loc = op.getLoc();
463 if (!isa<LLVM::LLVMArrayType>(operandType)) {
464 Type newOverflowType = typeConverter->convertType(overflowResultType);
466 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
467 Value addOverflow = LLVM::UAddWithOverflowOp::create(
468 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
470 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
471 Value overflowExtracted =
472 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
473 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
477 if (!isa<VectorType>(sumResultType))
478 return rewriter.notifyMatchFailure(loc,
"expected vector result types");
480 return rewriter.notifyMatchFailure(loc,
481 "ND vector types are not supported yet");
488template <
typename ArithMulOp,
bool IsSigned>
489LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
490 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
491 ConversionPatternRewriter &rewriter)
const {
492 Type resultType = adaptor.getLhs().getType();
497 Location loc = op.getLoc();
503 if (!isa<LLVM::LLVMArrayType>(resultType)) {
505 TypedAttr shiftValAttr;
507 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
508 unsigned resultBitwidth = intTy.getWidth();
509 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
510 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
512 auto vecTy = cast<VectorType>(resultType);
513 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
514 auto attrTy = VectorType::get(
515 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
516 shiftValAttr = SplatElementsAttr::get(
517 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
519 Type wideType = shiftValAttr.getType();
521 "LLVM dialect should support all signless integer types");
523 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
524 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
525 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
526 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
529 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
530 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
531 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
532 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
534 rewriter.replaceOp(op, {low, high});
538 if (!isa<VectorType>(resultType))
539 return rewriter.notifyMatchFailure(op,
"expected vector result type");
541 return rewriter.notifyMatchFailure(op,
542 "ND vector types are not supported yet");
551template <
typename LLVMPredType,
typename PredType>
553 return static_cast<LLVMPredType
>(pred);
557CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
558 ConversionPatternRewriter &rewriter)
const {
559 Type operandType = adaptor.getLhs().getType();
560 Type resultType = op.getResult().getType();
563 if (!isa<LLVM::LLVMArrayType>(operandType)) {
564 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
565 op, typeConverter->convertType(resultType),
567 adaptor.getLhs(), adaptor.getRhs());
571 if (!isa<VectorType>(resultType))
572 return rewriter.notifyMatchFailure(op,
"expected vector result type");
575 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
576 [&](Type llvm1DVectorTy,
ValueRange operands) {
577 OpAdaptor adaptor(operands);
578 return LLVM::ICmpOp::create(
579 rewriter, op.getLoc(), llvm1DVectorTy,
581 adaptor.getLhs(), adaptor.getRhs());
591CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
592 ConversionPatternRewriter &rewriter)
const {
594 op.getLhs().getType()))
595 return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
597 Type operandType = adaptor.getLhs().getType();
598 Type resultType = op.getResult().getType();
599 LLVM::FastmathFlags fmf =
600 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
603 if (!isa<LLVM::LLVMArrayType>(operandType)) {
604 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
605 op, typeConverter->convertType(resultType),
607 adaptor.getLhs(), adaptor.getRhs(), fmf);
611 if (!isa<VectorType>(resultType))
612 return rewriter.notifyMatchFailure(op,
"expected vector result type");
615 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
616 [&](Type llvm1DVectorTy,
ValueRange operands) {
617 OpAdaptor adaptor(operands);
618 return LLVM::FCmpOp::create(
619 rewriter, op.getLoc(), llvm1DVectorTy,
621 adaptor.getLhs(), adaptor.getRhs(), fmf);
633LogicalResult SelectOpOneToNLowering::matchAndRewrite(
634 arith::SelectOp op, Adaptor adaptor,
635 ConversionPatternRewriter &rewriter)
const {
637 if (llvm::hasSingleElement(adaptor.getTrueValue()))
638 return rewriter.notifyMatchFailure(
639 op,
"not a 1:N conversion, 1:1 pattern will match");
640 if (!op.getCondition().getType().isInteger(1))
641 return rewriter.notifyMatchFailure(op,
642 "non-i1 conditions are not supported");
643 SmallVector<Value> results;
644 for (
auto [trueValue, falseValue] :
645 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
646 results.push_back(arith::SelectOp::create(
647 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
648 rewriter.replaceOpWithMultiple(op, {results});
657struct ArithToLLVMConversionPass
658 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
661 void runOnOperation()
override {
667 options.overrideIndexBitwidth(indexBitwidth);
670 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
671 arith::populateArithToLLVMConversionPatterns(converter, patterns);
673 if (
failed(applyPartialConversion(getOperation(),
target,
674 std::move(patterns))))
686struct ArithToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
687 ArithToLLVMDialectInterface(Dialect *dialect)
688 : ConvertToLLVMPatternInterface(dialect) {}
690 void loadDependentDialects(MLIRContext *context)
const final {
691 context->loadDialect<LLVM::LLVMDialect>();
696 void populateConvertToLLVMConversionPatterns(
697 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
698 RewritePatternSet &patterns)
const final {
699 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
700 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
708 dialect->addInterfaces<ArithToLLVMDialectInterface>();
721 patterns.
add<IdentityBitcastLowering>(converter, patterns.
getContext(),
727 ConstrainedAddFOpLowering,
730 AddUIExtendedOpLowering,
736 ConstrainedDivFOpLowering,
745 IndexCastOpSILowering,
746 IndexCastOpUILowering,
756 ConstrainedMulFOpLowering,
758 MulSIExtendedOpLowering,
759 MulUIExtendedOpLowering,
766 SelectOpOneToNLowering,
772 ConstrainedSubFOpLowering,
775 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the LLVM data layout.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.