21#include "llvm/ADT/FloatingPointMode.h"
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
// Template parameter list (source op, target op) for a declaration whose
// remaining lines are not present in this listing.
32template <
typename SourceOp,
typename TargetOp>
// Alias template parameterized on a source op and a target op.
// NOTE(review): the right-hand side of this alias is not present in this
// listing.
35template <
typename SourceOp,
typename TargetOp>
36using ConvertFMFMathToLLVMPattern =
// Each alias below names one instantiation of ConvertFMFMathToLLVMPattern:
// the first template argument is the math-dialect op being matched and the
// second is the LLVM-dialect op it is paired with.
39using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41using CopySignOpLowering =
42 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
// NOTE(review): the right-hand side of CtPopFOpLowering is not present in
// this listing.
46using CtPopFOpLowering =
48using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50using FloorOpLowering =
51 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53using Log10OpLowering =
54 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
57using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58using FPowIOpLowering =
59 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60using RoundEvenOpLowering =
61 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62using RoundOpLowering =
63 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
// Note the pairing: math::TruncOp maps to LLVM::FTruncOp.
68using FTruncOpLowering =
69 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73using ATan2OpLowering =
74 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
// Fragment of the IntOpWithFlagLowering<MathOp, LLVMOp> conversion pattern.
// NOTE(review): the struct header and several body lines (return statements,
// closing braces, the callee of the final call) are not present in this
// listing.
78template <
typename MathOp,
typename LLVMOp>
// Inherit the constructors of ConvertOpToLLVMPattern<MathOp>.
80 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
81 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
// Rewrites `op` into an LLVMOp, reading the operand from `adaptor`.
84 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
85 ConversionPatternRewriter &rewriter)
const override {
86 const auto &typeConverter = *this->getTypeConverter();
// Run the adaptor operand's type through the type converter.
87 auto operandType = adaptor.getOperand().getType();
88 auto llvmOperandType = typeConverter.convertType(operandType);
92 auto loc = op.getLoc();
// Likewise convert the original op's result type.
93 auto resultType = op.getResult().getType();
94 auto llvmResultType = typeConverter.convertType(resultType);
// Operand type did not convert to an LLVM array type: replace the op directly
// with a single LLVMOp, passing `false` as the trailing argument.
98 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
99 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
100 adaptor.getOperand(),
false);
// NOTE(review): presumably reached only for the LLVM array case; the statement
// guarded by this vector-type check is not shown.
104 if (!isa<VectorType>(llvmResultType))
// Arguments of a call whose callee line is missing from this listing. The
// lambda receives a 1-D vector type plus operands and creates an LLVMOp on
// operands[0]; the remaining create() arguments are not shown.
108 op.getOperation(), adaptor.getOperands(), typeConverter,
109 [&](Type llvm1DVectorTy,
ValueRange operands) {
110 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
// Instantiations of IntOpWithFlagLowering: each pairs a math-dialect integer
// op (first argument) with the LLVM-dialect op it is rewritten to (second
// argument).
117using CountLeadingZerosOpLowering =
118 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119using CountTrailingZerosOpLowering =
120 IntOpWithFlagLowering<math::CountTrailingZerosOp,
121 LLVM::CountTrailingZerosOp>;
122using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
// Tail of a matchAndRewrite override that rewrites math::SincosOp into a
// single LLVM::SincosOp followed by two value extractions.
// NOTE(review): the start of the signature, the declaration of
// `typeConverter`, and the return statements are not present in this listing.
130 ConversionPatternRewriter &rewriter)
const override {
132 mlir::Location loc = op.getLoc();
// Convert the operand type and both result types (sin and cos).
133 mlir::Type operandType = adaptor.getOperand().getType();
134 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
135 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
136 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
// Guard on any failed type conversion; the guarded statement is not shown.
137 if (!llvmOperandType || !sinType || !cosType)
// Helper constructed from the source op; its getAttrs() result is forwarded to
// the LLVM::SincosOp created below.
140 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
// The LLVM op yields a literal struct with two fields, both of the converted
// operand type.
142 auto structType = LLVM::LLVMStructType::getLiteral(
143 rewriter.getContext(), {llvmOperandType, llvmOperandType});
145 auto sincosOp = LLVM::SincosOp::create(
146 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
// Field 0 is used as the sine result and field 1 as the cosine result.
148 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
149 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
// Replace the two-result source op with the extracted values, in that order.
151 rewriter.replaceOp(op, {sinValue, cosValue});
// Fragment of the conversion pattern for math::ExpM1Op. The visible code
// builds an LLVM::ExpOp and then an LLVM::FSubOp of {exp, one}, i.e.
// exp(x) - 1.
// NOTE(review): the struct header, some conditions, return statements and
// closing braces are not present in this listing.
158 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
161 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
162 ConversionPatternRewriter &rewriter)
const override {
163 const auto &typeConverter = *this->getTypeConverter();
164 auto operandType = adaptor.getOperand().getType();
165 auto llvmOperandType = typeConverter.convertType(operandType);
// Guard on a failed operand type conversion; the guarded statement is not
// shown.
166 if (!llvmOperandType)
169 auto loc = op.getLoc();
170 auto resultType = op.getResult().getType();
// NOTE(review): the argument of this cast is missing from the listing.
171 auto floatType = cast<FloatType>(
// Attribute holding the constant 1.0 in `floatType`.
173 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
// Separate helpers for the two generated ops; their getAttrs() results are
// forwarded to the ExpOp and FSubOp respectively.
174 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
175 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
// Path taken when the operand type did not convert to an LLVM array type.
177 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
178 LLVM::ConstantOp one;
// Two ways of materializing the constant `one`: a splat over the (shaped)
// operand type, or a plain constant from `floatOne`. The condition selecting
// between them is not shown.
180 one = LLVM::ConstantOp::create(
181 rewriter, loc, llvmOperandType,
182 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
186 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
188 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
189 expAttrs.getAttrs());
// Replace the source op with exp - one.
190 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
191 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
// A non-vector result type is reported as a match failure here.
195 if (!isa<VectorType>(resultType))
196 return rewriter.notifyMatchFailure(op,
"expected vector result type");
// Arguments of a call whose callee line is missing from this listing. The
// lambda is given a 1-D vector type and builds a splat constant, an ExpOp on
// operands[0], and an FSubOp for that slice.
199 op.getOperation(), adaptor.getOperands(), typeConverter,
200 [&](Type llvm1DVectorTy,
ValueRange operands) {
201 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
// Build the splat's vector type from the element count, carrying over whether
// the count is scalable.
202 auto splatAttr = SplatElementsAttr::get(
203 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
204 {numElements.isScalable()}),
206 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
208 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
209 operands[0], expAttrs.getAttrs());
210 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
212 subAttrs.getAttrs());
// Fragment of the conversion pattern for math::Log1pOp. The visible code
// creates an LLVM::FAddOp and then an LLVM::LogOp, using a constant 1.0.
// NOTE(review): the struct header, the operands of the add/log ops, return
// statements and closing braces are not present in this listing.
220 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
223 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
224 ConversionPatternRewriter &rewriter)
const override {
225 const auto &typeConverter = *this->getTypeConverter();
226 auto operandType = adaptor.getOperand().getType();
227 auto llvmOperandType = typeConverter.convertType(operandType);
// Report a match failure if the operand type could not be converted.
228 if (!llvmOperandType)
229 return rewriter.notifyMatchFailure(op,
"unsupported operand type");
231 auto loc = op.getLoc();
232 auto resultType = op.getResult().getType();
// NOTE(review): the argument of this cast is missing from the listing.
233 auto floatType = cast<FloatType>(
// Attribute holding the constant 1.0 in `floatType`.
235 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
// Separate helpers for the two generated ops; addAttrs.getAttrs() is forwarded
// to the FAddOp (the use of logAttrs is not shown).
236 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
237 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
// Path taken when the operand type did not convert to an LLVM array type.
239 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
// Constant `one`: a splat over the shaped type when the converted operand is
// a vector, otherwise a plain constant (remaining arguments not shown).
240 LLVM::ConstantOp one =
241 isa<VectorType>(llvmOperandType)
242 ? LLVM::ConstantOp::create(
243 rewriter, loc, llvmOperandType,
244 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
246 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
249 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
251 addAttrs.getAttrs());
// The source op is replaced by an LLVM::LogOp (arguments not shown).
252 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
// A non-vector result type is reported as a match failure here.
257 if (!isa<VectorType>(resultType))
258 return rewriter.notifyMatchFailure(op,
"expected vector result type");
// Arguments of a call whose callee line is missing from this listing. The
// lambda is given a 1-D vector type and builds a splat constant, an FAddOp,
// and a LogOp for that slice.
261 op.getOperation(), adaptor.getOperands(), typeConverter,
262 [&](Type llvm1DVectorTy,
ValueRange operands) {
263 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
// Build the splat's vector type from the element count, carrying over whether
// the count is scalable.
264 auto splatAttr = SplatElementsAttr::get(
265 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
266 {numElements.isScalable()}),
268 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
270 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
272 addAttrs.getAttrs());
273 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
// Fragment of the conversion pattern for math::RsqrtOp. The visible code
// builds an LLVM::SqrtOp and then an LLVM::FDivOp of {one, sqrt}, i.e.
// 1 / sqrt(x).
// NOTE(review): the struct header, return statements and closing braces are
// not present in this listing.
282 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
285 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter)
const override {
287 const auto &typeConverter = *this->getTypeConverter();
288 auto operandType = adaptor.getOperand().getType();
289 auto llvmOperandType = typeConverter.convertType(operandType);
// Guard on a failed operand type conversion; the guarded statement is not
// shown.
290 if (!llvmOperandType)
293 auto loc = op.getLoc();
294 auto resultType = op.getResult().getType();
// NOTE(review): the argument of this cast is missing from the listing.
295 auto floatType = cast<FloatType>(
// Attribute holding the constant 1.0 in `floatType`.
297 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
// Separate helpers for the two generated ops; their getAttrs() results are
// forwarded to the SqrtOp and FDivOp respectively.
298 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
299 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
// Path taken when the operand type did not convert to an LLVM array type.
301 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
302 LLVM::ConstantOp one;
// Vector operand: the constant is a splat over the shaped operand type.
303 if (isa<VectorType>(llvmOperandType)) {
304 one = LLVM::ConstantOp::create(
305 rewriter, loc, llvmOperandType,
306 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
// Otherwise a plain constant built from `floatOne` (the `else` and assignment
// lines are not shown).
310 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
312 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
313 sqrtAttrs.getAttrs());
// Replace the source op with one / sqrt.
314 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
315 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
// Vector-result check; the guarded statement is not shown.
319 if (!isa<VectorType>(resultType))
// Arguments of a call whose callee line is missing from this listing. The
// lambda is given a 1-D vector type and builds a splat constant, a SqrtOp on
// operands[0], and an FDivOp for that slice.
323 op.getOperation(), adaptor.getOperands(), typeConverter,
324 [&](Type llvm1DVectorTy,
ValueRange operands) {
325 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
// Build the splat's vector type from the element count, carrying over whether
// the count is scalable.
326 auto splatAttr = SplatElementsAttr::get(
327 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
328 {numElements.isScalable()}),
330 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
332 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
333 operands[0], sqrtAttrs.getAttrs());
334 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
336 divAttrs.getAttrs());
// Fragment of the conversion pattern for math::IsNaNOp: the op is replaced by
// LLVM::IsFPClass with the llvm::fcNan class mask.
// NOTE(review): the struct header, the declaration of `operandType`, and the
// return statements are not present in this listing.
343 using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
346 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
347 ConversionPatternRewriter &rewriter)
const override {
348 const auto &typeConverter = *this->getTypeConverter();
// Converted operand type (the left-hand side of this statement is not shown).
350 typeConverter.convertType(adaptor.getOperand().getType());
351 auto resultType = typeConverter.convertType(op.getResult().getType());
// Guard on a failed type conversion; the guarded statement is not shown.
352 if (!operandType || !resultType)
355 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
356 op, resultType, adaptor.getOperand(), llvm::fcNan);
// Fragment of the conversion pattern for math::IsFiniteOp: the op is replaced
// by LLVM::IsFPClass with the llvm::fcFinite class mask.
// NOTE(review): the struct header, the declaration of `operandType`, and the
// return statements are not present in this listing.
362 using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
365 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
366 ConversionPatternRewriter &rewriter)
const override {
367 const auto &typeConverter = *this->getTypeConverter();
// Converted operand type (the left-hand side of this statement is not shown).
369 typeConverter.convertType(adaptor.getOperand().getType());
370 auto resultType = typeConverter.convertType(op.getResult().getType());
// Guard on a failed type conversion; the guarded statement is not shown.
371 if (!operandType || !resultType)
374 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
375 op, resultType, adaptor.getOperand(), llvm::fcFinite);
// Fragment of the ConvertMathToLLVMPass definition.
// NOTE(review): the base-class list, the setup of `target` and the patterns,
// and the failure handling are not present in this listing.
380struct ConvertMathToLLVMPass
// Pass entry point.
384 void runOnOperation()
override {
// Applies a partial conversion to the pass's root operation using `target`;
// the remaining arguments and the body taken on failure are not shown.
389 if (
failed(applyPartialConversion(getOperation(),
target,
// Fragment of a pattern-registration routine; presumably the body of
// populateMathToLLVMConversionPatterns, whose signature appears in the
// reference text at the end of this listing — confirm.
// Log1pOpLowering is registered only when `approximateLog1p` is set.
399 if (approximateLog1p)
400 patterns.add<Log1pOpLowering>(converter, benefit);
// Part of a template argument list naming patterns that are all constructed
// with (converter, benefit); the call this list belongs to and most of its
// entries are not shown.
412 CountLeadingZerosOpLowering,
413 CountTrailingZerosOpLowering,
438 >(converter, benefit);
// Fragment of the MathToLLVMDialectInterface definition and its registration.
// NOTE(review): the struct header, the body of
// populateConvertToLLVMConversionPatterns, and the function enclosing the
// final addInterfaces call are not present in this listing.
// Loads the LLVM dialect into the given context.
450 void loadDependentDialects(
MLIRContext *context)
const final {
451 context->loadDialect<LLVM::LLVMDialect>();
// Hook receiving the conversion target, type converter and pattern set; its
// body is not shown.
456 void populateConvertToLLVMConversionPatterns(
457 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
458 RewritePatternSet &
patterns)
const final {
// Attaches MathToLLVMDialectInterface to `dialect`.
466 dialect->addInterfaces<MathToLLVMDialectInterface>();
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename math::SincosOp::Adaptor OpAdaptor
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override