21 #include "llvm/ADT/FloatingPointMode.h"
24 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
// Shorthand for arith::AttrConvertFastMathToLLVM specialized on a
// (source math op, target LLVM op) pair. The expm1 / log1p / rsqrt lowerings
// construct one of these per emitted LLVM op (e.g. `expAttrs(op)`) and pass
// its getAttrs() to the created op.
// NOTE(review): per its name this presumably translates the source op's
// fastmath attribute into the target op's fastmath flags; confirm against the
// arith AttrConvertFastMathToLLVM definition.
32 template <
typename SourceOp,
typename TargetOp>
33 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
35 template <
typename SourceOp,
typename TargetOp>
36 using ConvertFMFMathToLLVMPattern =
// One-to-one lowerings: each math op below is rewritten to the single LLVM
// dialect op named as the second template argument, via the shared
// ConvertFMFMathToLLVMPattern helper.
// Naming differs between dialects in places (math.absf -> LLVM::FAbsOp,
// math.ceil -> LLVM::FCeilOp, math.acos -> LLVM::ACosOp).
39 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41 using CopySignOpLowering =
42 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45 using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
46 using CtPopFOpLowering =
// More one-to-one math -> LLVM dialect lowerings built on
// ConvertFMFMathToLLVMPattern. Each alias pairs a math op with the single
// LLVM op that replaces it.
// Notable name mismatches: math.powf -> LLVM::PowOp,
// math.fpowi -> LLVM::PowIOp, math.trunc -> LLVM::FTruncOp,
// math.floor -> LLVM::FFloorOp, math.fma -> LLVM::FMAOp.
48 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50 using FloorOpLowering =
51 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53 using Log10OpLowering =
54 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
57 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58 using FPowIOpLowering =
59 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60 using RoundEvenOpLowering =
61 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62 using RoundOpLowering =
63 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65 using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66 using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
68 using FTruncOpLowering =
69 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70 using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71 using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72 using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73 using ATan2OpLowering =
74 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
78 template <
typename MathOp,
typename LLVMOp>
81 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
84 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
86 const auto &typeConverter = *this->getTypeConverter();
87 auto operandType = adaptor.getOperand().getType();
88 auto llvmOperandType = typeConverter.convertType(operandType);
92 auto loc = op.getLoc();
93 auto resultType = op.getResult().getType();
94 auto llvmResultType = typeConverter.convertType(resultType);
98 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
100 adaptor.getOperand(),
false);
104 if (!isa<VectorType>(llvmResultType))
108 op.getOperation(), adaptor.getOperands(), typeConverter,
110 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
// Integer math ops lowered through IntOpWithFlagLowering. That pattern
// creates the target LLVM op from the converted operand plus a constant
// `false` flag argument. For multi-dimensional (LLVM array) operands it
// unrolls through handleMultidimensionalVectors.
// NOTE(review): the `false` is presumably the LLVM "is_zero_poison"
// (ctlz/cttz) / "is_int_min_poison" (abs) flag, i.e. the result is defined
// for zero / INT_MIN inputs; TODO confirm against the LLVM dialect op defs.
117 using CountLeadingZerosOpLowering =
118 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119 using CountTrailingZerosOpLowering =
120 IntOpWithFlagLowering<math::CountTrailingZerosOp,
121 LLVM::CountTrailingZerosOp>;
122 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
129 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
131 const auto &typeConverter = *this->getTypeConverter();
132 auto operandType = adaptor.getOperand().getType();
133 auto llvmOperandType = typeConverter.convertType(operandType);
134 if (!llvmOperandType)
137 auto loc = op.getLoc();
138 auto resultType = op.getResult().getType();
139 auto floatType = cast<FloatType>(
142 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
143 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
145 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
146 LLVM::ConstantOp one;
148 one = rewriter.
create<LLVM::ConstantOp>(
149 loc, llvmOperandType,
153 one = rewriter.
create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
155 auto exp = rewriter.
create<LLVM::ExpOp>(loc, adaptor.getOperand(),
156 expAttrs.getAttrs());
158 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
162 if (!isa<VectorType>(resultType))
166 op.getOperation(), adaptor.getOperands(), typeConverter,
168 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
169 auto splatAttr = SplatElementsAttr::get(
170 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
171 {numElements.isScalable()}),
174 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
175 auto exp = rewriter.
create<LLVM::ExpOp>(
176 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
177 return rewriter.
create<LLVM::FSubOp>(
178 loc, llvm1DVectorTy,
ValueRange{exp, one}, subAttrs.getAttrs());
189 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
191 const auto &typeConverter = *this->getTypeConverter();
192 auto operandType = adaptor.getOperand().getType();
193 auto llvmOperandType = typeConverter.convertType(operandType);
194 if (!llvmOperandType)
197 auto loc = op.getLoc();
198 auto resultType = op.getResult().getType();
199 auto floatType = cast<FloatType>(
202 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
203 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
205 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
206 LLVM::ConstantOp one =
207 isa<VectorType>(llvmOperandType)
208 ? rewriter.
create<LLVM::ConstantOp>(
209 loc, llvmOperandType,
212 : rewriter.
create<LLVM::ConstantOp>(loc, llvmOperandType,
215 auto add = rewriter.
create<LLVM::FAddOp>(
216 loc, llvmOperandType,
ValueRange{one, adaptor.getOperand()},
217 addAttrs.getAttrs());
219 op, llvmOperandType,
ValueRange{add}, logAttrs.getAttrs());
223 if (!isa<VectorType>(resultType))
227 op.getOperation(), adaptor.getOperands(), typeConverter,
229 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
230 auto splatAttr = SplatElementsAttr::get(
231 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
232 {numElements.isScalable()}),
235 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
236 auto add = rewriter.
create<LLVM::FAddOp>(loc, llvm1DVectorTy,
238 addAttrs.getAttrs());
239 return rewriter.
create<LLVM::LogOp>(
240 loc, llvm1DVectorTy,
ValueRange{add}, logAttrs.getAttrs());
254 auto operandType = adaptor.getOperand().getType();
255 auto llvmOperandType = typeConverter.convertType(operandType);
256 if (!llvmOperandType)
259 auto loc = op.getLoc();
260 auto resultType = op.getResult().getType();
261 auto floatType = cast<FloatType>(
264 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
265 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
267 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
268 LLVM::ConstantOp one;
269 if (isa<VectorType>(llvmOperandType)) {
270 one = rewriter.
create<LLVM::ConstantOp>(
271 loc, llvmOperandType,
275 one = rewriter.
create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
277 auto sqrt = rewriter.
create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
278 sqrtAttrs.getAttrs());
280 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
284 if (!isa<VectorType>(resultType))
288 op.getOperation(), adaptor.getOperands(), typeConverter,
290 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
291 auto splatAttr = SplatElementsAttr::get(
292 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
293 {numElements.isScalable()}),
296 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
297 auto sqrt = rewriter.
create<LLVM::SqrtOp>(
298 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
299 return rewriter.
create<LLVM::FDivOp>(
300 loc, llvm1DVectorTy,
ValueRange{one, sqrt}, divAttrs.getAttrs());
314 typeConverter.convertType(adaptor.getOperand().getType());
315 auto resultType = typeConverter.convertType(op.getResult().getType());
316 if (!operandType || !resultType)
320 op, resultType, adaptor.getOperand(), llvm::fcNan);
333 typeConverter.convertType(adaptor.getOperand().getType());
334 auto resultType = typeConverter.convertType(op.getResult().getType());
335 if (!operandType || !resultType)
339 op, resultType, adaptor.getOperand(), llvm::fcFinite);
344 struct ConvertMathToLLVMPass
345 :
public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
348 void runOnOperation()
override {
363 if (approximateLog1p)
364 patterns.add<Log1pOpLowering>(converter, benefit);
376 CountLeadingZerosOpLowering,
377 CountTrailingZerosOpLowering,
401 >(converter, benefit);
414 context->loadDialect<LLVM::LLVMDialect>();
429 dialect->addInterfaces<MathToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)