20 #include <type_traits>
23 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
24 #include "mlir/Conversion/Passes.h.inc"
37 template <
typename SourceOp,
typename TargetOp,
bool Constrained,
38 template <
typename,
typename>
typename AttrConvert =
40 struct ConstrainedVectorConvertToLLVMPattern
46 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
48 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
51 AttrConvert>::matchAndRewrite(op, adaptor,
60 using AddFOpLowering =
62 arith::AttrConvertFastMathToLLVM>;
63 using AddIOpLowering =
65 arith::AttrConvertOverflowToLLVM>;
67 using BitcastOpLowering =
69 using DivFOpLowering =
71 arith::AttrConvertFastMathToLLVM>;
72 using DivSIOpLowering =
74 using DivUIOpLowering =
77 using ExtSIOpLowering =
79 using ExtUIOpLowering =
81 using FPToSIOpLowering =
83 using FPToUIOpLowering =
85 using MaximumFOpLowering =
87 arith::AttrConvertFastMathToLLVM>;
88 using MaxNumFOpLowering =
90 arith::AttrConvertFastMathToLLVM>;
91 using MaxSIOpLowering =
93 using MaxUIOpLowering =
95 using MinimumFOpLowering =
97 arith::AttrConvertFastMathToLLVM>;
98 using MinNumFOpLowering =
100 arith::AttrConvertFastMathToLLVM>;
101 using MinSIOpLowering =
103 using MinUIOpLowering =
105 using MulFOpLowering =
107 arith::AttrConvertFastMathToLLVM>;
108 using MulIOpLowering =
110 arith::AttrConvertOverflowToLLVM>;
111 using NegFOpLowering =
113 arith::AttrConvertFastMathToLLVM>;
115 using RemFOpLowering =
117 arith::AttrConvertFastMathToLLVM>;
118 using RemSIOpLowering =
120 using RemUIOpLowering =
122 using SelectOpLowering =
124 using ShLIOpLowering =
126 arith::AttrConvertOverflowToLLVM>;
127 using ShRSIOpLowering =
129 using ShRUIOpLowering =
131 using SIToFPOpLowering =
133 using SubFOpLowering =
135 arith::AttrConvertFastMathToLLVM>;
136 using SubIOpLowering =
138 arith::AttrConvertOverflowToLLVM>;
139 using TruncFOpLowering =
140 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
142 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
143 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
144 arith::AttrConverterConstrainedFPToLLVM>;
145 using TruncIOpLowering =
147 using UIToFPOpLowering =
160 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
168 template <
typename OpTy,
typename ExtCastTy>
173 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
177 using IndexCastOpSILowering =
178 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
179 using IndexCastOpUILowering =
180 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
182 struct AddUIExtendedOpLowering
187 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
191 template <
typename ArithMulOp,
bool IsSigned>
196 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
200 using MulSIExtendedOpLowering =
201 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
202 using MulUIExtendedOpLowering =
203 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
209 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
217 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
228 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
231 adaptor.getOperands(), op->getAttrs(),
232 *getTypeConverter(), rewriter);
239 template <
typename OpTy,
typename ExtCastTy>
240 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
241 OpTy op,
typename OpTy::Adaptor adaptor,
243 Type resultType = op.getResult().getType();
244 Type targetElementType =
246 Type sourceElementType =
251 if (targetBits == sourceBits) {
257 Type operandType = adaptor.getIn().getType();
258 if (!isa<LLVM::LLVMArrayType>(operandType)) {
259 Type targetType = this->typeConverter->convertType(resultType);
260 if (targetBits < sourceBits)
268 if (!isa<VectorType>(resultType))
272 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
274 typename OpTy::Adaptor adaptor(operands);
275 if (targetBits < sourceBits) {
276 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
279 return rewriter.
create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
289 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
290 arith::AddUIExtendedOp op, OpAdaptor adaptor,
292 Type operandType = adaptor.getLhs().getType();
293 Type sumResultType = op.getSum().getType();
294 Type overflowResultType = op.getOverflow().getType();
303 if (!isa<LLVM::LLVMArrayType>(operandType)) {
304 Type newOverflowType = typeConverter->convertType(overflowResultType);
306 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
307 Value addOverflow = rewriter.
create<LLVM::UAddWithOverflowOp>(
308 loc, structType, adaptor.getLhs(), adaptor.getRhs());
310 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
311 Value overflowExtracted =
312 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
313 rewriter.
replaceOp(op, {sumExtracted, overflowExtracted});
317 if (!isa<VectorType>(sumResultType))
321 "ND vector types are not supported yet");
328 template <
typename ArithMulOp,
bool IsSigned>
329 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
330 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
332 Type resultType = adaptor.getLhs().getType();
343 if (!isa<LLVM::LLVMArrayType>(resultType)) {
345 TypedAttr shiftValAttr;
347 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
348 unsigned resultBitwidth = intTy.getWidth();
352 auto vecTy = cast<VectorType>(resultType);
353 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
357 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
359 Type wideType = shiftValAttr.getType();
361 "LLVM dialect should support all signless integer types");
363 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
364 Value lhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
365 Value rhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
366 Value mulExt = rewriter.
create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
369 Value low = rewriter.
create<LLVM::TruncOp>(loc, resultType, mulExt);
370 Value shiftVal = rewriter.
create<LLVM::ConstantOp>(loc, shiftValAttr);
371 Value highExt = rewriter.
create<LLVM::LShrOp>(loc, mulExt, shiftVal);
372 Value high = rewriter.
create<LLVM::TruncOp>(loc, resultType, highExt);
378 if (!isa<VectorType>(resultType))
382 "ND vector types are not supported yet");
391 template <
typename LLVMPredType,
typename PredType>
393 return static_cast<LLVMPredType
>(pred);
397 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
399 Type operandType = adaptor.getLhs().getType();
400 Type resultType = op.getResult().getType();
403 if (!isa<LLVM::LLVMArrayType>(operandType)) {
405 op, typeConverter->convertType(resultType),
406 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
407 adaptor.getLhs(), adaptor.getRhs());
411 if (!isa<VectorType>(resultType))
415 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
417 OpAdaptor adaptor(operands);
418 return rewriter.
create<LLVM::ICmpOp>(
419 op.getLoc(), llvm1DVectorTy,
420 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
421 adaptor.getLhs(), adaptor.getRhs());
431 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
433 Type operandType = adaptor.getLhs().getType();
434 Type resultType = op.getResult().getType();
435 LLVM::FastmathFlags fmf =
439 if (!isa<LLVM::LLVMArrayType>(operandType)) {
441 op, typeConverter->convertType(resultType),
442 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
443 adaptor.getLhs(), adaptor.getRhs(), fmf);
447 if (!isa<VectorType>(resultType))
451 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
453 OpAdaptor adaptor(operands);
454 return rewriter.
create<LLVM::FCmpOp>(
455 op.getLoc(), llvm1DVectorTy,
456 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
457 adaptor.getLhs(), adaptor.getRhs(), fmf);
467 struct ArithToLLVMConversionPass
468 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
471 void runOnOperation()
override {
477 options.overrideIndexBitwidth(indexBitwidth);
483 std::move(patterns))))
498 context->loadDialect<LLVM::LLVMDialect>();
514 dialect->addInterfaces<ArithToLLVMDialectInterface>();
529 AddUIExtendedOpLowering,
542 IndexCastOpSILowering,
543 IndexCastOpUILowering,
554 MulSIExtendedOpLowering,
555 MulUIExtendedOpLowering,
569 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry ®istry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.