21 #include <type_traits>
24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25 #include "mlir/Conversion/Passes.h.inc"
// Conversion pattern for source ops that may carry an optional rounding-mode
// attribute. The `Constrained` template flag selects which form of the op
// this instantiation handles; AttrConvert (defaulted, default elided in this
// excerpt) controls how the source op's attributes are translated.
38 template <
typename SourceOp,
typename TargetOp,
bool Constrained,
39 template <
typename,
typename>
typename AttrConvert =
41 struct ConstrainedVectorConvertToLLVMPattern
47 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
// Only handle the op when "has a rounding-mode attribute" agrees with
// `Constrained`, so the constrained and unconstrained instantiations never
// both apply to one op. The early-exit statement is elided in this excerpt
// (presumably a match failure); otherwise the call below delegates to the
// base pattern's matchAndRewrite (base class name elided).
49 if (Constrained !=
static_cast<bool>(op.getRoundingModeAttr()))
52 AttrConvert>::matchAndRewrite(op, adaptor,
// Lowers arith.bitcast when the type converter maps the op's result type to
// exactly the type of the (already converted) operand. In that case no LLVM
// op is needed and the bitcast is replaced by its operand. Otherwise the
// pattern reports a match failure ("Types are different").
59 struct IdentityBitcastLowering final
64 matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
66 Value src = adaptor.getIn();
67 Type resultType = getTypeConverter()->convertType(op.getType());
68 if (src.
getType() != resultType)
69 return rewriter.notifyMatchFailure(op,
"Types are different");
// Identity cast after type conversion: forward the converted operand.
71 rewriter.replaceOp(op, src);
// Per-op lowering pattern aliases. In the lines visible here:
// - AddF, DivF, MaximumF, MaxNumF, MinimumF, MinNumF, MulF, NegF, RemF and
//   SubF are parameterized with arith::AttrConvertFastMathToLLVM, which
//   carries fastmath flags over.
// - AddI, MulI, ShLI, SubI and TruncI use arith::AttrConvertOverflowToLLVM,
//   which carries integer overflow flags over.
// The remaining template arguments of most aliases are elided in this
// excerpt.
80 using AddFOpLowering =
82 arith::AttrConvertFastMathToLLVM>;
83 using AddIOpLowering =
85 arith::AttrConvertOverflowToLLVM>;
87 using BitcastOpLowering =
89 using DivFOpLowering =
91 arith::AttrConvertFastMathToLLVM>;
92 using DivSIOpLowering =
94 using DivUIOpLowering =
97 using ExtSIOpLowering =
99 using ExtUIOpLowering =
101 using FPToSIOpLowering =
103 using FPToUIOpLowering =
105 using MaximumFOpLowering =
107 arith::AttrConvertFastMathToLLVM>;
108 using MaxNumFOpLowering =
110 arith::AttrConvertFastMathToLLVM>;
111 using MaxSIOpLowering =
113 using MaxUIOpLowering =
115 using MinimumFOpLowering =
117 arith::AttrConvertFastMathToLLVM>;
118 using MinNumFOpLowering =
120 arith::AttrConvertFastMathToLLVM>;
121 using MinSIOpLowering =
123 using MinUIOpLowering =
125 using MulFOpLowering =
127 arith::AttrConvertFastMathToLLVM>;
128 using MulIOpLowering =
130 arith::AttrConvertOverflowToLLVM>;
131 using NegFOpLowering =
133 arith::AttrConvertFastMathToLLVM>;
135 using RemFOpLowering =
137 arith::AttrConvertFastMathToLLVM>;
138 using RemSIOpLowering =
140 using RemUIOpLowering =
142 using SelectOpLowering =
144 using ShLIOpLowering =
146 arith::AttrConvertOverflowToLLVM>;
147 using ShRSIOpLowering =
149 using ShRUIOpLowering =
151 using SIToFPOpLowering =
153 using SubFOpLowering =
155 arith::AttrConvertFastMathToLLVM>;
156 using SubIOpLowering =
158 arith::AttrConvertOverflowToLLVM>;
// arith.truncf has two patterns built on ConstrainedVectorConvertToLLVMPattern:
// - ConstrainedTruncFOpLowering (Constrained = true) targets
//   LLVM::ConstrainedFPTruncIntr and uses
//   arith::AttrConverterConstrainedFPToLLVM.
// - TruncFOpLowering targets LLVM::FPTruncOp. Its remaining arguments are
//   elided here; presumably Constrained = false, so it handles truncf
//   without a rounding mode.
159 using TruncFOpLowering =
160 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
162 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
163 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr,
true,
164 arith::AttrConverterConstrainedFPToLLVM>;
165 using TruncIOpLowering =
167 arith::AttrConvertOverflowToLLVM>;
168 using UIToFPOpLowering =
181 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
// Shared lowering for arith.index_cast / arith.index_castui. ExtCastTy is the
// LLVM op used when the target is wider than the source.
189 template <
typename OpTy,
typename ExtCastTy>
194 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
// index_cast widens with sign extension; index_castui widens with zero
// extension.
198 using IndexCastOpSILowering =
199 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
200 using IndexCastOpUILowering =
201 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
// Lowers arith.addui_extended (sum plus overflow result).
203 struct AddUIExtendedOpLowering
208 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
// Shared lowering for arith.mulsi_extended / arith.mului_extended. IsSigned
// selects sign- vs zero-extension of the operands.
212 template <
typename ArithMulOp,
bool IsSigned>
217 matchAndRewrite(ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
221 using MulSIExtendedOpLowering =
222 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
223 using MulUIExtendedOpLowering =
224 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
// Comparison lowerings: arith.cmpi -> LLVM icmp, arith.cmpf -> LLVM fcmp.
230 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
238 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
// Rewrites arith.constant one-to-one, forwarding the converted operands and
// the op's complete attribute list. The callee name is elided in this
// excerpt; presumably it is LLVM::detail::oneToOneRewrite, whose argument
// order matches.
249 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
252 adaptor.getOperands(), op->getAttrs(),
253 *getTypeConverter(), rewriter);
// Lowers index_cast / index_castui by comparing the element bit widths of the
// source and target types:
// - a narrower target uses LLVM::TruncOp;
// - a wider target uses ExtCastTy (SExt or ZExt, per the aliases);
// - equal widths take a separate branch (body elided).
260 template <
typename OpTy,
typename ExtCastTy>
261 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
262 OpTy op,
typename OpTy::Adaptor adaptor,
264 Type resultType = op.getResult().getType();
265 Type targetElementType =
267 Type sourceElementType =
// Same width: no trunc/ext is required (branch body elided in this excerpt;
// presumably the converted operand is forwarded).
272 if (targetBits == sourceBits) {
// Operands the type converter did not turn into an LLVM array are cast
// directly to the converted result type.
278 Type operandType = adaptor.getIn().getType();
279 if (!isa<LLVM::LLVMArrayType>(operandType)) {
280 Type targetType = this->typeConverter->convertType(resultType);
281 if (targetBits < sourceBits)
// LLVM-array operands (presumably N-D vectors) are only handled for
// VectorType results. The cast is then built on each 1-D slice through a
// multidimensional-vector helper (call head elided here; presumably
// LLVM::detail::handleMultidimensionalVectors).
289 if (!isa<VectorType>(resultType))
293 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
295 typename OpTy::Adaptor adaptor(operands);
// Per slice: truncate when narrowing, otherwise extend with ExtCastTy.
296 if (targetBits < sourceBits) {
297 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
300 return rewriter.
create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
// Lowers arith.addui_extended. For non-array operands it emits
// LLVM::UAddWithOverflowOp and unpacks its struct result into the op's two
// results. LLVM-array operands take a separate path that mentions
// "ND vector types are not supported yet".
310 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
311 arith::AddUIExtendedOp op, OpAdaptor adaptor,
313 Type operandType = adaptor.getLhs().getType();
314 Type sumResultType = op.getSum().getType();
315 Type overflowResultType = op.getOverflow().getType();
// Direct path. Only the overflow type is run through the type converter;
// the sum element of the struct keeps the op's sum type as-is.
324 if (!isa<LLVM::LLVMArrayType>(operandType)) {
325 Type newOverflowType = typeConverter->convertType(overflowResultType);
327 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
// Struct field 0 is the sum and field 1 the overflow flag. They replace the
// op's results in that order.
328 Value addOverflow = rewriter.
create<LLVM::UAddWithOverflowOp>(
329 loc, structType, adaptor.getLhs(), adaptor.getRhs());
331 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
332 Value overflowExtracted =
333 rewriter.
create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
334 rewriter.
replaceOp(op, {sumExtracted, overflowExtracted});
// Array-typed operands: not lowered here. The surrounding failure paths are
// partly elided in this excerpt.
338 if (!isa<VectorType>(sumResultType))
342 "ND vector types are not supported yet");
// Lowers arith.mulsi_extended / arith.mului_extended:
// 1. widen both operands to twice the bit width;
// 2. multiply in the wide type;
// 3. split the product into a low half (truncate) and a high half (logical
//    shift right by the original width, then truncate).
349 template <
typename ArithMulOp,
bool IsSigned>
350 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
351 ArithMulOp op,
typename ArithMulOp::Adaptor adaptor,
353 Type resultType = adaptor.getLhs().getType();
364 if (!isa<LLVM::LLVMArrayType>(resultType)) {
// The shift amount equals the original bit width (of the integer, or of
// the vector element). It is stored in a 2x-wide APInt so its type can also
// serve as the widened operand type (attrTy construction elided).
366 TypedAttr shiftValAttr;
368 if (
auto intTy = dyn_cast<IntegerType>(resultType)) {
369 unsigned resultBitwidth = intTy.getWidth();
373 auto vecTy = cast<VectorType>(resultType);
374 unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
378 attrTy, APInt(resultBitwidth * 2, resultBitwidth));
380 Type wideType = shiftValAttr.getType();
382 "LLVM dialect should support all signless integer types");
// Signed variants sign-extend; unsigned variants zero-extend.
384 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
385 Value lhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
386 Value rhsExt = rewriter.
create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
387 Value mulExt = rewriter.
create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
// LShr is correct for both signednesses: the shifted value is truncated
// back to the original width, so the bits the shift fills in are dropped.
390 Value low = rewriter.
create<LLVM::TruncOp>(loc, resultType, mulExt);
391 Value shiftVal = rewriter.
create<LLVM::ConstantOp>(loc, shiftValAttr);
392 Value highExt = rewriter.
create<LLVM::LShrOp>(loc, mulExt, shiftVal);
393 Value high = rewriter.
create<LLVM::TruncOp>(loc, resultType, highExt);
// Array-typed operands take a separate path (partly elided) that mentions
// "ND vector types are not supported yet".
399 if (!isa<VectorType>(resultType))
403 "ND vector types are not supported yet");
// Converts an arith comparison predicate to the matching LLVM predicate
// enum with a plain static_cast. This relies on both enums sharing the same
// underlying numbering.
// NOTE(review): nothing here checks that the two enums stay in sync; confirm
// against the enum definitions.
412 template <
typename LLVMPredType,
typename PredType>
414 return static_cast<LLVMPredType
>(pred);
// Lowers arith.cmpi to LLVM::ICmpOp. Non-array operands are replaced by a
// single ICmpOp on the converted result type. LLVM-array operands are only
// handled for VectorType results, with one ICmpOp per 1-D slice.
418 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
420 Type operandType = adaptor.getLhs().getType();
421 Type resultType = op.getResult().getType();
424 if (!isa<LLVM::LLVMArrayType>(operandType)) {
426 op, typeConverter->convertType(resultType),
427 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
428 adaptor.getLhs(), adaptor.getRhs());
// Slice-wise path: the helper call head is elided here (presumably
// LLVM::detail::handleMultidimensionalVectors). The lambda builds each
// slice's compare.
432 if (!isa<VectorType>(resultType))
436 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
438 OpAdaptor adaptor(operands);
439 return rewriter.
create<LLVM::ICmpOp>(
440 op.getLoc(), llvm1DVectorTy,
441 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
442 adaptor.getLhs(), adaptor.getRhs());
// Lowers arith.cmpf to LLVM::FCmpOp and passes fastmath flags to every
// compare it creates. Non-array operands are replaced by a single FCmpOp;
// LLVM-array operands are only handled for VectorType results, one FCmpOp
// per 1-D slice.
452 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
454 Type operandType = adaptor.getLhs().getType();
455 Type resultType = op.getResult().getType();
// Initializer elided in this excerpt; presumably derived from the op's
// arith fastmath attribute.
456 LLVM::FastmathFlags fmf =
460 if (!isa<LLVM::LLVMArrayType>(operandType)) {
462 op, typeConverter->convertType(resultType),
463 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
464 adaptor.getLhs(), adaptor.getRhs(), fmf);
// Slice-wise path: the helper call head is elided here. The lambda builds
// each slice's compare with the same predicate and flags.
468 if (!isa<VectorType>(resultType))
472 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
474 OpAdaptor adaptor(operands);
475 return rewriter.
create<LLVM::FCmpOp>(
476 op.getLoc(), llvm1DVectorTy,
477 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
478 adaptor.getLhs(), adaptor.getRhs(), fmf);
// Standalone pass that runs the arith-to-LLVM lowering. It derives from the
// TableGen-generated pass base (see GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS).
488 struct ArithToLLVMConversionPass
489 :
public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
// Builds the lowering options; the visible line applies the pass's
// index-bitwidth option. Its guarding condition is elided here.
492 void runOnOperation()
override {
498 options.overrideIndexBitwidth(indexBitwidth);
// The conversion targets the LLVM dialect, so that dialect is loaded up
// front.
520 context->loadDialect<LLVM::LLVMDialect>();
// Attaches the convert-to-LLVM interface to a dialect (enclosing
// registration code elided; presumably the arith dialect).
537 dialect->addInterfaces<ArithToLLVMDialectInterface>();
// Excerpt of the pattern list registered by the populate function; the
// other entries are elided.
558 AddUIExtendedOpLowering,
571 IndexCastOpSILowering,
572 IndexCastOpUILowering,
583 MulSIExtendedOpLowering,
584 MulUIExtendedOpLowering,
598 ConstrainedTruncFOpLowering,
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.