21 #include <type_traits>
24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25 #include "mlir/Conversion/Passes.h.inc"
38 template <
typename SourceOp,
typename TargetOp,
bool Constrained,
39 template <
typename,
typename>
typename AttrConvert =
41 struct ConstrainedVectorConvertToLLVMPattern
47 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
49 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
52 AttrConvert>::matchAndRewrite(op, adaptor,
61 using AddFOpLowering =
63 arith::AttrConvertFastMathToLLVM>;
64 using AddIOpLowering =
66 arith::AttrConvertOverflowToLLVM>;
68 using BitcastOpLowering =
70 using DivFOpLowering =
72 arith::AttrConvertFastMathToLLVM>;
73 using DivSIOpLowering =
75 using DivUIOpLowering =
78 using ExtSIOpLowering =
80 using ExtUIOpLowering =
82 using FPToSIOpLowering =
84 using FPToUIOpLowering =
86 using MaximumFOpLowering =
88 arith::AttrConvertFastMathToLLVM>;
89 using MaxNumFOpLowering =
91 arith::AttrConvertFastMathToLLVM>;
92 using MaxSIOpLowering =
94 using MaxUIOpLowering =
96 using MinimumFOpLowering =
98 arith::AttrConvertFastMathToLLVM>;
99 using MinNumFOpLowering =
101 arith::AttrConvertFastMathToLLVM>;
102 using MinSIOpLowering =
104 using MinUIOpLowering =
106 using MulFOpLowering =
108 arith::AttrConvertFastMathToLLVM>;
109 using MulIOpLowering =
111 arith::AttrConvertOverflowToLLVM>;
112 using NegFOpLowering =
114 arith::AttrConvertFastMathToLLVM>;
116 using RemFOpLowering =
118 arith::AttrConvertFastMathToLLVM>;
119 using RemSIOpLowering =
121 using RemUIOpLowering =
123 using SelectOpLowering =
125 using ShLIOpLowering =
127 arith::AttrConvertOverflowToLLVM>;
128 using ShRSIOpLowering =
130 using ShRUIOpLowering =
132 using SIToFPOpLowering =
134 using SubFOpLowering =
136 arith::AttrConvertFastMathToLLVM>;
137 using SubIOpLowering =
139 arith::AttrConvertOverflowToLLVM>;
140 using TruncFOpLowering =
141 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
143 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
144 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
145 arith::AttrConverterConstrainedFPToLLVM>;
146 using TruncIOpLowering =
148 using UIToFPOpLowering =
161 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
169 template <
typename OpTy,
typename ExtCastTy>
174 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
178 using IndexCastOpSILowering =
179 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
180 using IndexCastOpUILowering =
181 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
183 struct AddUIExtendedOpLowering
188 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
192 template <
typename ArithMulOp,
bool IsSigned>
197 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
201 using MulSIExtendedOpLowering =
202 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
203 using MulUIExtendedOpLowering =
204 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
210 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
218 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
229 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
232 adaptor.getOperands(), op->getAttrs(),
233 *getTypeConverter(), rewriter);
240 template <
typename OpTy,
typename ExtCastTy>
241 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
242 OpTy op,
typename OpTy::Adaptor adaptor,
244 Type resultType = op.getResult().getType();
245 Type targetElementType =
247 Type sourceElementType =
252 if (targetBits == sourceBits) {
258 Type operandType = adaptor.getIn().getType();
259 if (!isa<LLVM::LLVMArrayType>(operandType)) {
260 Type targetType = this->typeConverter->convertType(resultType);
261 if (targetBits < sourceBits)
269 if (!isa<VectorType>(resultType))
273 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
275 typename OpTy::Adaptor adaptor(operands);
276 if (targetBits < sourceBits) {
277 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
280 return rewriter.
create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
290 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
291 arith::AddUIExtendedOp op, OpAdaptor adaptor,
293 Type operandType = adaptor.getLhs().getType();
294 Type sumResultType = op.getSum().getType();
295 Type overflowResultType = op.getOverflow().getType();
304 if (!isa<LLVM::LLVMArrayType>(operandType)) {
305 Type newOverflowType = typeConverter->convertType(overflowResultType);
307 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
308 Value addOverflow = rewriter.
create<LLVM::UAddWithOverflowOp>(
309 loc, structType, adaptor.getLhs(), adaptor.getRhs());
311 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
312 Value overflowExtracted =
313 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
314 rewriter.
replaceOp(op, {sumExtracted, overflowExtracted});
318 if (!isa<VectorType>(sumResultType))
322 "ND vector types are not supported yet");
329 template <
typename ArithMulOp,
bool IsSigned>
330 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
331 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
333 Type resultType = adaptor.getLhs().getType();
344 if (!isa<LLVM::LLVMArrayType>(resultType)) {
346 TypedAttr shiftValAttr;
348 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
349 unsigned resultBitwidth = intTy.getWidth();
353 auto vecTy = cast<VectorType>(resultType);
354 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
358 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
360 Type wideType = shiftValAttr.getType();
362 "LLVM dialect should support all signless integer types");
364 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
365 Value lhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
366 Value rhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
367 Value mulExt = rewriter.
create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
370 Value low = rewriter.
create<LLVM::TruncOp>(loc, resultType, mulExt);
371 Value shiftVal = rewriter.
create<LLVM::ConstantOp>(loc, shiftValAttr);
372 Value highExt = rewriter.
create<LLVM::LShrOp>(loc, mulExt, shiftVal);
373 Value high = rewriter.
create<LLVM::TruncOp>(loc, resultType, highExt);
379 if (!isa<VectorType>(resultType))
383 "ND vector types are not supported yet");
392 template <
typename LLVMPredType,
typename PredType>
394 return static_cast<LLVMPredType
>(pred);
398 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
400 Type operandType = adaptor.getLhs().getType();
401 Type resultType = op.getResult().getType();
404 if (!isa<LLVM::LLVMArrayType>(operandType)) {
406 op, typeConverter->convertType(resultType),
407 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
408 adaptor.getLhs(), adaptor.getRhs());
412 if (!isa<VectorType>(resultType))
416 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
418 OpAdaptor adaptor(operands);
419 return rewriter.
create<LLVM::ICmpOp>(
420 op.getLoc(), llvm1DVectorTy,
421 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
422 adaptor.getLhs(), adaptor.getRhs());
432 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
434 Type operandType = adaptor.getLhs().getType();
435 Type resultType = op.getResult().getType();
436 LLVM::FastmathFlags fmf =
440 if (!isa<LLVM::LLVMArrayType>(operandType)) {
442 op, typeConverter->convertType(resultType),
443 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
444 adaptor.getLhs(), adaptor.getRhs(), fmf);
448 if (!isa<VectorType>(resultType))
452 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
454 OpAdaptor adaptor(operands);
455 return rewriter.
create<LLVM::FCmpOp>(
456 op.getLoc(), llvm1DVectorTy,
457 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
458 adaptor.getLhs(), adaptor.getRhs(), fmf);
468 struct ArithToLLVMConversionPass
469 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
472 void runOnOperation()
override {
478 options.overrideIndexBitwidth(indexBitwidth);
500 context->loadDialect<LLVM::LLVMDialect>();
517 dialect->addInterfaces<ArithToLLVMDialectInterface>();
532 AddUIExtendedOpLowering,
545 IndexCastOpSILowering,
546 IndexCastOpUILowering,
557 MulSIExtendedOpLowering,
558 MulUIExtendedOpLowering,
572 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.