20 #include <type_traits>
23 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
24 #include "mlir/Conversion/Passes.h.inc"
37 template <
typename SourceOp,
typename TargetOp,
bool Constrained,
38 template <
typename,
typename>
typename AttrConvert =
40 struct ConstrainedVectorConvertToLLVMPattern
46 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
48 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
51 AttrConvert>::matchAndRewrite(op, adaptor,
58 struct IdentityBitcastLowering final
63 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
65 Value src = adaptor.getIn();
66 Type resultType = getTypeConverter()->convertType(op.getType());
67 if (src.
getType() != resultType)
68 return rewriter.notifyMatchFailure(op,
"Types are different");
70 rewriter.replaceOp(op, src);
79 using AddFOpLowering =
81 arith::AttrConvertFastMathToLLVM>;
82 using AddIOpLowering =
84 arith::AttrConvertOverflowToLLVM>;
86 using BitcastOpLowering =
88 using DivFOpLowering =
90 arith::AttrConvertFastMathToLLVM>;
91 using DivSIOpLowering =
93 using DivUIOpLowering =
96 using ExtSIOpLowering =
98 using ExtUIOpLowering =
100 using FPToSIOpLowering =
102 using FPToUIOpLowering =
104 using MaximumFOpLowering =
106 arith::AttrConvertFastMathToLLVM>;
107 using MaxNumFOpLowering =
109 arith::AttrConvertFastMathToLLVM>;
110 using MaxSIOpLowering =
112 using MaxUIOpLowering =
114 using MinimumFOpLowering =
116 arith::AttrConvertFastMathToLLVM>;
117 using MinNumFOpLowering =
119 arith::AttrConvertFastMathToLLVM>;
120 using MinSIOpLowering =
122 using MinUIOpLowering =
124 using MulFOpLowering =
126 arith::AttrConvertFastMathToLLVM>;
127 using MulIOpLowering =
129 arith::AttrConvertOverflowToLLVM>;
130 using NegFOpLowering =
132 arith::AttrConvertFastMathToLLVM>;
134 using RemFOpLowering =
136 arith::AttrConvertFastMathToLLVM>;
137 using RemSIOpLowering =
139 using RemUIOpLowering =
141 using SelectOpLowering =
143 using ShLIOpLowering =
145 arith::AttrConvertOverflowToLLVM>;
146 using ShRSIOpLowering =
148 using ShRUIOpLowering =
150 using SIToFPOpLowering =
152 using SubFOpLowering =
154 arith::AttrConvertFastMathToLLVM>;
155 using SubIOpLowering =
157 arith::AttrConvertOverflowToLLVM>;
158 using TruncFOpLowering =
159 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
161 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
162 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
163 arith::AttrConverterConstrainedFPToLLVM>;
164 using TruncIOpLowering =
166 arith::AttrConvertOverflowToLLVM>;
167 using UIToFPOpLowering =
180 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
188 template <
typename OpTy,
typename ExtCastTy>
193 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
197 using IndexCastOpSILowering =
198 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
199 using IndexCastOpUILowering =
200 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
202 struct AddUIExtendedOpLowering
207 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
211 template <
typename ArithMulOp,
bool IsSigned>
216 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
220 using MulSIExtendedOpLowering =
221 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
222 using MulUIExtendedOpLowering =
223 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
229 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
237 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
247 matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
258 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
261 adaptor.getOperands(), op->getAttrs(),
262 *getTypeConverter(), rewriter);
269 template <
typename OpTy,
typename ExtCastTy>
270 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
271 OpTy op,
typename OpTy::Adaptor adaptor,
273 Type resultType = op.getResult().getType();
274 Type targetElementType =
276 Type sourceElementType =
281 if (targetBits == sourceBits) {
287 Type operandType = adaptor.getIn().getType();
288 if (!isa<LLVM::LLVMArrayType>(operandType)) {
289 Type targetType = this->typeConverter->convertType(resultType);
290 if (targetBits < sourceBits)
298 if (!isa<VectorType>(resultType))
302 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
304 typename OpTy::Adaptor adaptor(operands);
305 if (targetBits < sourceBits) {
306 return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
309 return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
319 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
320 arith::AddUIExtendedOp op, OpAdaptor adaptor,
322 Type operandType = adaptor.getLhs().getType();
323 Type sumResultType = op.getSum().getType();
324 Type overflowResultType = op.getOverflow().getType();
333 if (!isa<LLVM::LLVMArrayType>(operandType)) {
334 Type newOverflowType = typeConverter->convertType(overflowResultType);
336 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
337 Value addOverflow = LLVM::UAddWithOverflowOp::create(
338 rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
340 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
341 Value overflowExtracted =
342 LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
343 rewriter.
replaceOp(op, {sumExtracted, overflowExtracted});
347 if (!isa<VectorType>(sumResultType))
351 "ND vector types are not supported yet");
358 template <
typename ArithMulOp,
bool IsSigned>
359 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
360 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
362 Type resultType = adaptor.getLhs().getType();
373 if (!isa<LLVM::LLVMArrayType>(resultType)) {
375 TypedAttr shiftValAttr;
377 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
378 unsigned resultBitwidth = intTy.getWidth();
382 auto vecTy = cast<VectorType>(resultType);
383 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
387 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
389 Type wideType = shiftValAttr.getType();
391 "LLVM dialect should support all signless integer types");
393 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
394 Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
395 Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
396 Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
399 Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
400 Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
401 Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
402 Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
408 if (!isa<VectorType>(resultType))
412 "ND vector types are not supported yet");
421 template <
typename LLVMPredType,
typename PredType>
423 return static_cast<LLVMPredType
>(pred);
427 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
429 Type operandType = adaptor.getLhs().getType();
430 Type resultType = op.getResult().getType();
433 if (!isa<LLVM::LLVMArrayType>(operandType)) {
435 op, typeConverter->convertType(resultType),
436 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
437 adaptor.getLhs(), adaptor.getRhs());
441 if (!isa<VectorType>(resultType))
445 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
447 OpAdaptor adaptor(operands);
448 return LLVM::ICmpOp::create(
449 rewriter, op.getLoc(), llvm1DVectorTy,
450 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
451 adaptor.getLhs(), adaptor.getRhs());
461 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
463 Type operandType = adaptor.getLhs().getType();
464 Type resultType = op.getResult().getType();
465 LLVM::FastmathFlags fmf =
469 if (!isa<LLVM::LLVMArrayType>(operandType)) {
471 op, typeConverter->convertType(resultType),
472 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
473 adaptor.getLhs(), adaptor.getRhs(), fmf);
477 if (!isa<VectorType>(resultType))
481 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
483 OpAdaptor adaptor(operands);
484 return LLVM::FCmpOp::create(
485 rewriter, op.getLoc(), llvm1DVectorTy,
486 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
487 adaptor.getLhs(), adaptor.getRhs(), fmf);
499 LogicalResult SelectOpOneToNLowering::matchAndRewrite(
500 arith::SelectOp op, Adaptor adaptor,
503 if (llvm::hasSingleElement(adaptor.getTrueValue()))
505 op,
"not a 1:N conversion, 1:1 pattern will match");
506 if (!op.getCondition().getType().isInteger(1))
508 "non-i1 conditions are not supported");
510 for (
auto [trueValue, falseValue] :
511 llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
512 results.push_back(arith::SelectOp::create(
513 rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
523 struct ArithToLLVMConversionPass
524 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
527 void runOnOperation()
override {
533 options.overrideIndexBitwidth(indexBitwidth);
555 context->loadDialect<LLVM::LLVMDialect>();
572 dialect->addInterfaces<ArithToLLVMDialectInterface>();
593 AddUIExtendedOpLowering,
606 IndexCastOpSILowering,
607 IndexCastOpUILowering,
618 MulSIExtendedOpLowering,
619 MulUIExtendedOpLowering,
626 SelectOpOneToNLowering,
634 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry ®istry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.