21 #include "llvm/ADT/FloatingPointMode.h"
24 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
// Alias for arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>.
// The multi-op expansions in this file build one instance per emitted LLVM op
// (e.g. `ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);`).
// They then pass its `getAttrs()` to that op's create call.
// Judging by the name, it presumably converts the source op's fastmath
// attribute into the target LLVM op's attribute form. Confirm against the
// arith helper.
32 template <
typename SourceOp,
typename TargetOp>
33 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
35 template <
typename SourceOp,
typename TargetOp>
36 using ConvertFMFMathToLLVMPattern =
// One-to-one lowerings: each alias below maps a single math-dialect op onto
// the LLVM-dialect op of the same meaning via ConvertFMFMathToLLVMPattern.
// That pattern's full definition is not visible here. The "FMF" in its name
// suggests it also forwards fastmath flags; confirm.
39 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41 using CopySignOpLowering =
42 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
// Note the capitalisation differs between dialects: math::AcosOp -> LLVM::ACosOp.
45 using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
46 using CtPopFOpLowering =
// More one-to-one floating-point lowerings built on
// ConvertFMFMathToLLVMPattern. Each alias pairs a math-dialect source op with
// the LLVM-dialect op it is rewritten to.
48 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50 using FloorOpLowering =
51 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53 using Log10OpLowering =
54 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
// math.powf lowers to LLVM::PowOp. math.fpowi (presumably a float base with an
// integer exponent, per the name; confirm) lowers to LLVM::PowIOp.
57 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58 using FPowIOpLowering =
59 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60 using RoundEvenOpLowering =
61 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62 using RoundOpLowering =
63 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65 using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66 using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
// math::TruncOp is lowered to LLVM::FTruncOp, the floating-point truncation
// op; hence the "FTrunc" prefix on the alias name.
68 using FTruncOpLowering =
69 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70 using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71 using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72 using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73 using ATan2OpLowering =
74 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
78 template <
typename MathOp,
typename LLVMOp>
81 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
84 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
86 const auto &typeConverter = *this->getTypeConverter();
87 auto operandType = adaptor.getOperand().getType();
88 auto llvmOperandType = typeConverter.convertType(operandType);
92 auto loc = op.getLoc();
93 auto resultType = op.getResult().getType();
94 auto llvmResultType = typeConverter.convertType(resultType);
98 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
100 adaptor.getOperand(),
false);
104 if (!isa<VectorType>(llvmResultType))
108 op.getOperation(), adaptor.getOperands(), typeConverter,
110 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
// Integer math ops lowered through IntOpWithFlagLowering.
// In the code visible for that pattern, the LLVM op is created with the
// converted operand followed by a constant `false` flag. For LLVM's ctlz/cttz
// this is presumably the is-zero-poison flag, and for llvm.abs the
// is-int-min-poison flag; confirm against the LLVM dialect op definitions.
// The count-zeros aliases are referenced by name when the pass registers its
// patterns, so keep these names stable.
117 using CountLeadingZerosOpLowering =
118 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119 using CountTrailingZerosOpLowering =
120 IntOpWithFlagLowering<math::CountTrailingZerosOp,
121 LLVM::CountTrailingZerosOp>;
122 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
129 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
131 const auto &typeConverter = *this->getTypeConverter();
132 auto operandType = adaptor.getOperand().getType();
133 auto llvmOperandType = typeConverter.convertType(operandType);
134 if (!llvmOperandType)
137 auto loc = op.getLoc();
138 auto resultType = op.getResult().getType();
139 auto floatType = cast<FloatType>(
142 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
143 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
145 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
146 LLVM::ConstantOp one;
148 one = LLVM::ConstantOp::create(
149 rewriter, loc, llvmOperandType,
154 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
156 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
157 expAttrs.getAttrs());
159 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
163 if (!isa<VectorType>(resultType))
167 op.getOperation(), adaptor.getOperands(), typeConverter,
169 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
170 auto splatAttr = SplatElementsAttr::get(
171 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
172 {numElements.isScalable()}),
174 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
176 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
177 operands[0], expAttrs.getAttrs());
178 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
180 subAttrs.getAttrs());
191 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
193 const auto &typeConverter = *this->getTypeConverter();
194 auto operandType = adaptor.getOperand().getType();
195 auto llvmOperandType = typeConverter.convertType(operandType);
196 if (!llvmOperandType)
199 auto loc = op.getLoc();
200 auto resultType = op.getResult().getType();
201 auto floatType = cast<FloatType>(
204 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
205 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
207 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
208 LLVM::ConstantOp one =
209 isa<VectorType>(llvmOperandType)
210 ? LLVM::ConstantOp::create(
211 rewriter, loc, llvmOperandType,
214 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
217 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
219 addAttrs.getAttrs());
225 if (!isa<VectorType>(resultType))
229 op.getOperation(), adaptor.getOperands(), typeConverter,
231 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
232 auto splatAttr = SplatElementsAttr::get(
233 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
234 {numElements.isScalable()}),
236 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
238 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
240 addAttrs.getAttrs());
241 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
256 auto operandType = adaptor.getOperand().getType();
257 auto llvmOperandType = typeConverter.convertType(operandType);
258 if (!llvmOperandType)
261 auto loc = op.getLoc();
262 auto resultType = op.getResult().getType();
263 auto floatType = cast<FloatType>(
266 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
267 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
269 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
270 LLVM::ConstantOp one;
271 if (isa<VectorType>(llvmOperandType)) {
272 one = LLVM::ConstantOp::create(
273 rewriter, loc, llvmOperandType,
278 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
280 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
281 sqrtAttrs.getAttrs());
283 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
287 if (!isa<VectorType>(resultType))
291 op.getOperation(), adaptor.getOperands(), typeConverter,
293 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
294 auto splatAttr = SplatElementsAttr::get(
295 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
296 {numElements.isScalable()}),
298 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
300 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
301 operands[0], sqrtAttrs.getAttrs());
302 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
304 divAttrs.getAttrs());
318 typeConverter.convertType(adaptor.getOperand().getType());
319 auto resultType = typeConverter.convertType(op.getResult().getType());
320 if (!operandType || !resultType)
324 op, resultType, adaptor.getOperand(), llvm::fcNan);
337 typeConverter.convertType(adaptor.getOperand().getType());
338 auto resultType = typeConverter.convertType(op.getResult().getType());
339 if (!operandType || !resultType)
343 op, resultType, adaptor.getOperand(), llvm::fcFinite);
348 struct ConvertMathToLLVMPass
349 :
public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
352 void runOnOperation()
override {
367 if (approximateLog1p)
368 patterns.add<Log1pOpLowering>(converter, benefit);
380 CountLeadingZerosOpLowering,
381 CountTrailingZerosOpLowering,
405 >(converter, benefit);
418 context->loadDialect<LLVM::LLVMDialect>();
433 dialect->addInterfaces<MathToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)