21 #include "llvm/ADT/FloatingPointMode.h"
24 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
32 template <
typename SourceOp,
typename TargetOp>
33 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
35 template <
typename SourceOp,
typename TargetOp>
36 using ConvertFMFMathToLLVMPattern =
39 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41 using CopySignOpLowering =
42 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45 using CtPopFOpLowering =
47 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
48 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
49 using FloorOpLowering =
50 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
51 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
52 using Log10OpLowering =
53 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
54 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
55 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
56 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
57 using FPowIOpLowering =
58 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
59 using RoundEvenOpLowering =
60 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
61 using RoundOpLowering =
62 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
63 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
64 using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
65 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
66 using FTruncOpLowering =
67 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
68 using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
69 using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72 template <
typename MathOp,
typename LLVMOp>
75 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
78 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
80 auto operandType = adaptor.getOperand().getType();
85 auto loc = op.getLoc();
86 auto resultType = op.getResult().getType();
88 if (!isa<LLVM::LLVMArrayType>(operandType)) {
94 auto vectorType = dyn_cast<VectorType>(resultType);
99 op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
101 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
108 using CountLeadingZerosOpLowering =
109 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
110 using CountTrailingZerosOpLowering =
111 IntOpWithFlagLowering<math::CountTrailingZerosOp,
112 LLVM::CountTrailingZerosOp>;
113 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
120 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
122 auto operandType = adaptor.getOperand().getType();
127 auto loc = op.getLoc();
128 auto resultType = op.getResult().getType();
131 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
132 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
134 if (!isa<LLVM::LLVMArrayType>(operandType)) {
135 LLVM::ConstantOp one;
137 one = rewriter.
create<LLVM::ConstantOp>(
141 one = rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
143 auto exp = rewriter.
create<LLVM::ExpOp>(loc, adaptor.getOperand(),
144 expAttrs.getAttrs());
146 op, operandType,
ValueRange{exp, one}, subAttrs.getAttrs());
150 auto vectorType = dyn_cast<VectorType>(resultType);
155 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
160 {numElements.isScalable()}),
163 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
164 auto exp = rewriter.
create<LLVM::ExpOp>(
165 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
166 return rewriter.
create<LLVM::FSubOp>(
167 loc, llvm1DVectorTy,
ValueRange{exp, one}, subAttrs.getAttrs());
178 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
180 auto operandType = adaptor.getOperand().getType();
185 auto loc = op.getLoc();
186 auto resultType = op.getResult().getType();
189 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
190 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
192 if (!isa<LLVM::LLVMArrayType>(operandType)) {
193 LLVM::ConstantOp one =
195 ? rewriter.
create<LLVM::ConstantOp>(
199 : rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
201 auto add = rewriter.
create<LLVM::FAddOp>(
202 loc, operandType,
ValueRange{one, adaptor.getOperand()},
203 addAttrs.getAttrs());
205 logAttrs.getAttrs());
209 auto vectorType = dyn_cast<VectorType>(resultType);
214 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219 {numElements.isScalable()}),
222 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
223 auto add = rewriter.
create<LLVM::FAddOp>(loc, llvm1DVectorTy,
225 addAttrs.getAttrs());
226 return rewriter.
create<LLVM::LogOp>(
227 loc, llvm1DVectorTy,
ValueRange{add}, logAttrs.getAttrs());
238 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
240 auto operandType = adaptor.getOperand().getType();
245 auto loc = op.getLoc();
246 auto resultType = op.getResult().getType();
249 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
250 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
252 if (!isa<LLVM::LLVMArrayType>(operandType)) {
253 LLVM::ConstantOp one;
255 one = rewriter.
create<LLVM::ConstantOp>(
259 one = rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
261 auto sqrt = rewriter.
create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
262 sqrtAttrs.getAttrs());
264 op, operandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
268 auto vectorType = dyn_cast<VectorType>(resultType);
273 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
278 {numElements.isScalable()}),
281 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
282 auto sqrt = rewriter.
create<LLVM::SqrtOp>(
283 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
284 return rewriter.
create<LLVM::FDivOp>(
285 loc, llvm1DVectorTy,
ValueRange{one, sqrt}, divAttrs.getAttrs());
295 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
297 auto operandType = adaptor.getOperand().getType();
303 op, op.getType(), adaptor.getOperand(), llvm::fcNan);
312 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
314 auto operandType = adaptor.getOperand().getType();
320 op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
325 struct ConvertMathToLLVMPass
326 :
public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
329 void runOnOperation()
override {
344 if (approximateLog1p)
345 patterns.add<Log1pOpLowering>(converter, benefit);
356 CountLeadingZerosOpLowering,
357 CountTrailingZerosOpLowering,
378 >(converter, benefit);
390 void loadDependentDialects(
MLIRContext *context)
const final {
391 context->loadDialect<LLVM::LLVMDialect>();
396 void populateConvertToLLVMConversionPatterns(
406 dialect->addInterfaces<MathToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)