21 #include <type_traits>
24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25 #include "mlir/Conversion/Passes.h.inc"
38 template <
typename SourceOp,
typename TargetOp,
bool Constrained,
39 template <
typename,
typename>
typename AttrConvert =
41 struct ConstrainedVectorConvertToLLVMPattern
47 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
49 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
52 AttrConvert>::matchAndRewrite(op, adaptor,
59 struct IdentityBitcastLowering final
64 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
66 Value src = adaptor.getIn();
67 Type resultType = getTypeConverter()->convertType(op.getType());
68 if (src.
getType() != resultType)
69 return rewriter.notifyMatchFailure(op,
"Types are different");
71 rewriter.replaceOp(op, src);
80 using AddFOpLowering =
82 arith::AttrConvertFastMathToLLVM>;
83 using AddIOpLowering =
85 arith::AttrConvertOverflowToLLVM>;
87 using BitcastOpLowering =
89 using DivFOpLowering =
91 arith::AttrConvertFastMathToLLVM>;
92 using DivSIOpLowering =
94 using DivUIOpLowering =
97 using ExtSIOpLowering =
99 using ExtUIOpLowering =
101 using FPToSIOpLowering =
103 using FPToUIOpLowering =
105 using MaximumFOpLowering =
107 arith::AttrConvertFastMathToLLVM>;
108 using MaxNumFOpLowering =
110 arith::AttrConvertFastMathToLLVM>;
111 using MaxSIOpLowering =
113 using MaxUIOpLowering =
115 using MinimumFOpLowering =
117 arith::AttrConvertFastMathToLLVM>;
118 using MinNumFOpLowering =
120 arith::AttrConvertFastMathToLLVM>;
121 using MinSIOpLowering =
123 using MinUIOpLowering =
125 using MulFOpLowering =
127 arith::AttrConvertFastMathToLLVM>;
128 using MulIOpLowering =
130 arith::AttrConvertOverflowToLLVM>;
131 using NegFOpLowering =
133 arith::AttrConvertFastMathToLLVM>;
135 using RemFOpLowering =
137 arith::AttrConvertFastMathToLLVM>;
138 using RemSIOpLowering =
140 using RemUIOpLowering =
142 using SelectOpLowering =
144 using ShLIOpLowering =
146 arith::AttrConvertOverflowToLLVM>;
147 using ShRSIOpLowering =
149 using ShRUIOpLowering =
151 using SIToFPOpLowering =
153 using SubFOpLowering =
155 arith::AttrConvertFastMathToLLVM>;
156 using SubIOpLowering =
158 arith::AttrConvertOverflowToLLVM>;
159 using TruncFOpLowering =
160 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
162 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
163 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
164 arith::AttrConverterConstrainedFPToLLVM>;
165 using TruncIOpLowering =
167 using UIToFPOpLowering =
180 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
188 template <
typename OpTy,
typename ExtCastTy>
193 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
197 using IndexCastOpSILowering =
198 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
199 using IndexCastOpUILowering =
200 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
202 struct AddUIExtendedOpLowering
207 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
211 template <
typename ArithMulOp,
bool IsSigned>
216 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
220 using MulSIExtendedOpLowering =
221 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
222 using MulUIExtendedOpLowering =
223 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
229 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
237 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
248 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
251 adaptor.getOperands(), op->getAttrs(),
252 *getTypeConverter(), rewriter);
259 template <
typename OpTy,
typename ExtCastTy>
260 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
261 OpTy op,
typename OpTy::Adaptor adaptor,
263 Type resultType = op.getResult().getType();
264 Type targetElementType =
266 Type sourceElementType =
271 if (targetBits == sourceBits) {
277 Type operandType = adaptor.getIn().getType();
278 if (!isa<LLVM::LLVMArrayType>(operandType)) {
279 Type targetType = this->typeConverter->convertType(resultType);
280 if (targetBits < sourceBits)
288 if (!isa<VectorType>(resultType))
292 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
294 typename OpTy::Adaptor adaptor(operands);
295 if (targetBits < sourceBits) {
296 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
299 return rewriter.
create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
309 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
310 arith::AddUIExtendedOp op, OpAdaptor adaptor,
312 Type operandType = adaptor.getLhs().getType();
313 Type sumResultType = op.getSum().getType();
314 Type overflowResultType = op.getOverflow().getType();
323 if (!isa<LLVM::LLVMArrayType>(operandType)) {
324 Type newOverflowType = typeConverter->convertType(overflowResultType);
326 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
327 Value addOverflow = rewriter.
create<LLVM::UAddWithOverflowOp>(
328 loc, structType, adaptor.getLhs(), adaptor.getRhs());
330 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
331 Value overflowExtracted =
332 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
333 rewriter.
replaceOp(op, {sumExtracted, overflowExtracted});
337 if (!isa<VectorType>(sumResultType))
341 "ND vector types are not supported yet");
348 template <
typename ArithMulOp,
bool IsSigned>
349 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
350 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
352 Type resultType = adaptor.getLhs().getType();
363 if (!isa<LLVM::LLVMArrayType>(resultType)) {
365 TypedAttr shiftValAttr;
367 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
368 unsigned resultBitwidth = intTy.getWidth();
372 auto vecTy = cast<VectorType>(resultType);
373 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
377 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
379 Type wideType = shiftValAttr.getType();
381 "LLVM dialect should support all signless integer types");
383 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
384 Value lhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
385 Value rhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
386 Value mulExt = rewriter.
create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
389 Value low = rewriter.
create<LLVM::TruncOp>(loc, resultType, mulExt);
390 Value shiftVal = rewriter.
create<LLVM::ConstantOp>(loc, shiftValAttr);
391 Value highExt = rewriter.
create<LLVM::LShrOp>(loc, mulExt, shiftVal);
392 Value high = rewriter.
create<LLVM::TruncOp>(loc, resultType, highExt);
398 if (!isa<VectorType>(resultType))
402 "ND vector types are not supported yet");
411 template <
typename LLVMPredType,
typename PredType>
413 return static_cast<LLVMPredType
>(pred);
417 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
419 Type operandType = adaptor.getLhs().getType();
420 Type resultType = op.getResult().getType();
423 if (!isa<LLVM::LLVMArrayType>(operandType)) {
425 op, typeConverter->convertType(resultType),
426 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
427 adaptor.getLhs(), adaptor.getRhs());
431 if (!isa<VectorType>(resultType))
435 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
437 OpAdaptor adaptor(operands);
438 return rewriter.
create<LLVM::ICmpOp>(
439 op.getLoc(), llvm1DVectorTy,
440 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
441 adaptor.getLhs(), adaptor.getRhs());
451 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
453 Type operandType = adaptor.getLhs().getType();
454 Type resultType = op.getResult().getType();
455 LLVM::FastmathFlags fmf =
459 if (!isa<LLVM::LLVMArrayType>(operandType)) {
461 op, typeConverter->convertType(resultType),
462 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
463 adaptor.getLhs(), adaptor.getRhs(), fmf);
467 if (!isa<VectorType>(resultType))
471 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
473 OpAdaptor adaptor(operands);
474 return rewriter.
create<LLVM::FCmpOp>(
475 op.getLoc(), llvm1DVectorTy,
476 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
477 adaptor.getLhs(), adaptor.getRhs(), fmf);
487 struct ArithToLLVMConversionPass
488 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
491 void runOnOperation()
override {
497 options.overrideIndexBitwidth(indexBitwidth);
519 context->loadDialect<LLVM::LLVMDialect>();
536 dialect->addInterfaces<ArithToLLVMDialectInterface>();
557 AddUIExtendedOpLowering,
570 IndexCastOpSILowering,
571 IndexCastOpUILowering,
582 MulSIExtendedOpLowering,
583 MulUIExtendedOpLowering,
597 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry ®istry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.