23#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
24#include "mlir/Conversion/Passes.h.inc"
37template <
typename SourceOp,
typename TargetOp,
bool Constrained,
38 template <
typename,
typename>
typename AttrConvert =
40 bool FailOnUnsupportedFP =
false>
41struct ConstrainedVectorConvertToLLVMPattern
43 FailOnUnsupportedFP> {
44 using VectorConvertToLLVMPattern<
45 SourceOp, TargetOp, AttrConvert,
46 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
49 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
50 ConversionPatternRewriter &rewriter)
const override {
51 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
53 return VectorConvertToLLVMPattern<
54 SourceOp, TargetOp, AttrConvert,
55 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
61struct IdentityBitcastLowering final
62 :
public OpConversionPattern<arith::BitcastOp> {
66 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
67 ConversionPatternRewriter &rewriter)
const final {
68 Value src = adaptor.getIn();
69 Type resultType = getTypeConverter()->convertType(op.getType());
70 if (src.
getType() != resultType)
71 return rewriter.notifyMatchFailure(op,
"Types are different");
73 rewriter.replaceOp(op, src);
90using BitcastOpLowering =
96using DivSIOpLowering =
98using DivUIOpLowering =
103using ExtSIOpLowering =
105using ExtUIOpLowering =
107using FPToSIOpLowering =
111using FPToUIOpLowering =
115using MaximumFOpLowering =
119using MaxNumFOpLowering =
123using MaxSIOpLowering =
125using MaxUIOpLowering =
127using MinimumFOpLowering =
131using MinNumFOpLowering =
135using MinSIOpLowering =
137using MinUIOpLowering =
139using MulFOpLowering =
143using MulIOpLowering =
146using NegFOpLowering =
151using RemFOpLowering =
155using RemSIOpLowering =
157using RemUIOpLowering =
159using SelectOpLowering =
161using ShLIOpLowering =
164using ShRSIOpLowering =
166using ShRUIOpLowering =
168using SIToFPOpLowering =
170using SubFOpLowering =
174using SubIOpLowering =
177using TruncFOpLowering =
178 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
181using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
182 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
184using TruncIOpLowering =
187using UIToFPOpLowering =
202 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter)
const override;
210template <
typename OpTy,
typename ExtCastTy>
212 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
215 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
216 ConversionPatternRewriter &rewriter)
const override;
219using IndexCastOpSILowering =
220 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
221using IndexCastOpUILowering =
222 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
224struct AddUIExtendedOpLowering
229 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter)
const override;
233template <
typename ArithMulOp,
bool IsSigned>
235 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
238 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
239 ConversionPatternRewriter &rewriter)
const override;
242using MulSIExtendedOpLowering =
243 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
244using MulUIExtendedOpLowering =
245 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
251 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
252 ConversionPatternRewriter &rewriter)
const override;
259 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter)
const override;
268 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
269 ConversionPatternRewriter &rewriter)
const override;
279ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
280 ConversionPatternRewriter &rewriter)
const {
282 adaptor.getOperands(), op->getAttrs(),
283 *getTypeConverter(), rewriter);
290template <
typename OpTy,
typename ExtCastTy>
291LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
292 OpTy op,
typename OpTy::Adaptor adaptor,
293 ConversionPatternRewriter &rewriter)
const {
294 Type resultType = op.getResult().getType();
302 if (targetBits == sourceBits) {
303 rewriter.replaceOp(op, adaptor.getIn());
308 Type operandType = adaptor.getIn().getType();
309 if (!isa<LLVM::LLVMArrayType>(operandType)) {
310 Type targetType = this->typeConverter->convertType(resultType);
311 if (targetBits < sourceBits)
312 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
315 rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
319 if (!isa<VectorType>(resultType))
320 return rewriter.notifyMatchFailure(op,
"expected vector result type");
323 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
325 typename OpTy::Adaptor adaptor(operands);
326 if (targetBits < sourceBits) {
327 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
330 return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
340LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
341 arith::AddUIExtendedOp op, OpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter)
const {
343 Type operandType = adaptor.getLhs().getType();
344 Type sumResultType = op.getSum().getType();
345 Type overflowResultType = op.getOverflow().getType();
354 if (!isa<LLVM::LLVMArrayType>(operandType)) {
355 Type newOverflowType = typeConverter->convertType(overflowResultType);
357 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
358 Value addOverflow = LLVM::UAddWithOverflowOp::create(
359 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
361 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
362 Value overflowExtracted =
363 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
364 rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
368 if (!isa<VectorType>(sumResultType))
369 return rewriter.notifyMatchFailure(loc,
"expected vector result types");
371 return rewriter.notifyMatchFailure(loc,
372 "ND vector types are not supported yet");
379template <
typename ArithMulOp,
bool IsSigned>
380LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
381 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
382 ConversionPatternRewriter &rewriter)
const {
383 Type resultType = adaptor.getLhs().getType();
388 Location loc = op.getLoc();
394 if (!isa<LLVM::LLVMArrayType>(resultType)) {
396 TypedAttr shiftValAttr;
398 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
399 unsigned resultBitwidth = intTy.getWidth();
400 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
401 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
403 auto vecTy = cast<VectorType>(resultType);
404 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
405 auto attrTy = VectorType::get(
406 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
407 shiftValAttr = SplatElementsAttr::get(
408 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
410 Type wideType = shiftValAttr.getType();
412 "LLVM dialect should support all signless integer types");
414 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
415 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
416 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
417 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
420 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
421 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
422 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
423 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
425 rewriter.replaceOp(op, {low, high});
429 if (!isa<VectorType>(resultType))
430 return rewriter.notifyMatchFailure(op,
"expected vector result type");
432 return rewriter.notifyMatchFailure(op,
433 "ND vector types are not supported yet");
442template <
typename LLVMPredType,
typename PredType>
444 return static_cast<LLVMPredType
>(pred);
448CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
449 ConversionPatternRewriter &rewriter)
const {
450 Type operandType = adaptor.getLhs().getType();
451 Type resultType = op.getResult().getType();
454 if (!isa<LLVM::LLVMArrayType>(operandType)) {
455 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
456 op, typeConverter->convertType(resultType),
458 adaptor.getLhs(), adaptor.getRhs());
462 if (!isa<VectorType>(resultType))
463 return rewriter.notifyMatchFailure(op,
"expected vector result type");
466 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
467 [&](Type llvm1DVectorTy,
ValueRange operands) {
468 OpAdaptor adaptor(operands);
469 return LLVM::ICmpOp::create(
470 rewriter, op.getLoc(), llvm1DVectorTy,
472 adaptor.getLhs(), adaptor.getRhs());
482CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
483 ConversionPatternRewriter &rewriter)
const {
484 Type operandType = adaptor.getLhs().getType();
485 Type resultType = op.getResult().getType();
486 LLVM::FastmathFlags fmf =
487 arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
490 if (!isa<LLVM::LLVMArrayType>(operandType)) {
491 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
492 op, typeConverter->convertType(resultType),
494 adaptor.getLhs(), adaptor.getRhs(), fmf);
498 if (!isa<VectorType>(resultType))
499 return rewriter.notifyMatchFailure(op,
"expected vector result type");
502 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
503 [&](Type llvm1DVectorTy,
ValueRange operands) {
504 OpAdaptor adaptor(operands);
505 return LLVM::FCmpOp::create(
506 rewriter, op.getLoc(), llvm1DVectorTy,
508 adaptor.getLhs(), adaptor.getRhs(), fmf);
520LogicalResult SelectOpOneToNLowering::matchAndRewrite(
521 arith::SelectOp op, Adaptor adaptor,
522 ConversionPatternRewriter &rewriter)
const {
524 if (llvm::hasSingleElement(adaptor.getTrueValue()))
525 return rewriter.notifyMatchFailure(
526 op,
"not a 1:N conversion, 1:1 pattern will match");
527 if (!op.getCondition().getType().isInteger(1))
528 return rewriter.notifyMatchFailure(op,
529 "non-i1 conditions are not supported");
530 SmallVector<Value> results;
531 for (
auto [trueValue, falseValue] :
532 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
533 results.push_back(arith::SelectOp::create(
534 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
535 rewriter.replaceOpWithMultiple(op, {results});
544struct ArithToLLVMConversionPass
545 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
548 void runOnOperation()
override {
554 options.overrideIndexBitwidth(indexBitwidth);
557 arith::populateCeilFloorDivExpandOpsPatterns(
patterns);
558 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
560 if (
failed(applyPartialConversion(getOperation(),
target,
573struct ArithToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
575 void loadDependentDialects(MLIRContext *context)
const final {
576 context->loadDialect<LLVM::LLVMDialect>();
581 void populateConvertToLLVMConversionPatterns(
582 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
583 RewritePatternSet &
patterns)
const final {
584 arith::populateCeilFloorDivExpandOpsPatterns(
patterns);
585 arith::populateArithToLLVMConversionPatterns(typeConverter,
patterns);
593 dialect->addInterfaces<ArithToLLVMDialectInterface>();
614 AddUIExtendedOpLowering,
627 IndexCastOpSILowering,
628 IndexCastOpUILowering,
639 MulSIExtendedOpLowering,
640 MulUIExtendedOpLowering,
647 SelectOpOneToNLowering,
655 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns