21#include "llvm/ADT/FloatingPointMode.h"
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
32template <
typename SourceOp,
typename TargetOp>
35template <
typename SourceOp,
typename TargetOp,
bool FailOnUnsupportedFP = true>
36using ConvertFMFMathToLLVMPattern =
// Lowering pattern aliases: each alias below instantiates
// ConvertFMFMathToLLVMPattern with a math dialect op as the source and the
// LLVM dialect op listed next to it as the target.
40using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
41using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
42using CopySignOpLowering =
43 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
44using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
45using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
46using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
47using CtPopFOpLowering =
// Further lowering pattern aliases built on ConvertFMFMathToLLVMPattern:
// the first template argument is the math dialect op being matched, the
// second is the LLVM dialect op it is paired with.
51using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
52using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
53using FloorOpLowering =
54 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
55using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
56using Log10OpLowering =
57 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
58using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
59using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
60using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
61using FPowIOpLowering =
62 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
63using RoundEvenOpLowering =
64 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
65using RoundOpLowering =
66 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
67using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
68using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
69using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
70using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
71using FTruncOpLowering =
72 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
73using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
74using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
75using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
76using ATan2OpLowering =
77 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
81template <
typename MathOp,
typename LLVMOp>
82struct IntOpWithFlagLowering
84 using ConvertOpToLLVMPattern<
85 MathOp,
true>::ConvertOpToLLVMPattern;
86 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
89 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
90 ConversionPatternRewriter &rewriter)
const override {
91 const auto &typeConverter = *this->getTypeConverter();
92 auto operandType = adaptor.getOperand().getType();
93 auto llvmOperandType = typeConverter.convertType(operandType);
97 auto loc = op.getLoc();
98 auto resultType = op.getResult().getType();
99 auto llvmResultType = typeConverter.convertType(resultType);
103 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
104 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
105 adaptor.getOperand(),
false);
109 if (!isa<VectorType>(llvmResultType))
113 op.getOperation(), adaptor.getOperands(), typeConverter,
114 [&](Type llvm1DVectorTy,
ValueRange operands) {
115 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
// Integer lowering aliases: each one instantiates IntOpWithFlagLowering with
// a math dialect integer op (count-leading-zeros, count-trailing-zeros, absi)
// and the LLVM dialect op named as its target.
122using CountLeadingZerosOpLowering =
123 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
124using CountTrailingZerosOpLowering =
125 IntOpWithFlagLowering<math::CountTrailingZerosOp,
126 LLVM::CountTrailingZerosOp>;
127using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
138 ConversionPatternRewriter &rewriter)
const override {
140 mlir::Location loc = op.getLoc();
141 mlir::Type operandType = adaptor.getOperand().getType();
142 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
143 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
144 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
145 if (!llvmOperandType || !sinType || !cosType)
148 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
150 auto structType = LLVM::LLVMStructType::getLiteral(
151 rewriter.getContext(), {llvmOperandType, llvmOperandType});
153 auto sincosOp = LLVM::SincosOp::create(
154 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
156 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
157 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
159 rewriter.replaceOp(op, {sinValue, cosValue});
165struct ExpM1OpLowering
168 using ConvertOpToLLVMPattern<
169 math::ExpM1Op,
true>::ConvertOpToLLVMPattern;
172 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter)
const override {
174 const auto &typeConverter = *this->getTypeConverter();
175 auto operandType = adaptor.getOperand().getType();
176 auto llvmOperandType = typeConverter.convertType(operandType);
177 if (!llvmOperandType)
180 auto loc = op.getLoc();
181 auto resultType = op.getResult().getType();
182 auto floatType = cast<FloatType>(
184 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
185 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
186 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
188 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
189 LLVM::ConstantOp one;
191 one = LLVM::ConstantOp::create(
192 rewriter, loc, llvmOperandType,
193 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
197 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
199 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
200 expAttrs.getAttrs());
201 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
202 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
206 if (!isa<VectorType>(resultType))
207 return rewriter.notifyMatchFailure(op,
"expected vector result type");
210 op.getOperation(), adaptor.getOperands(), typeConverter,
211 [&](Type llvm1DVectorTy,
ValueRange operands) {
212 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
213 auto splatAttr = SplatElementsAttr::get(
214 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
215 {numElements.isScalable()}),
217 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
219 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
220 operands[0], expAttrs.getAttrs());
221 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
223 subAttrs.getAttrs());
230struct Log1pOpLowering
233 using ConvertOpToLLVMPattern<
234 math::Log1pOp,
true>::ConvertOpToLLVMPattern;
237 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter)
const override {
239 const auto &typeConverter = *this->getTypeConverter();
240 auto operandType = adaptor.getOperand().getType();
241 auto llvmOperandType = typeConverter.convertType(operandType);
242 if (!llvmOperandType)
243 return rewriter.notifyMatchFailure(op,
"unsupported operand type");
245 auto loc = op.getLoc();
246 auto resultType = op.getResult().getType();
247 auto floatType = cast<FloatType>(
249 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
250 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
251 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
253 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
254 LLVM::ConstantOp one =
255 isa<VectorType>(llvmOperandType)
256 ? LLVM::ConstantOp::create(
257 rewriter, loc, llvmOperandType,
258 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
260 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
263 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
265 addAttrs.getAttrs());
266 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
271 if (!isa<VectorType>(resultType))
272 return rewriter.notifyMatchFailure(op,
"expected vector result type");
275 op.getOperation(), adaptor.getOperands(), typeConverter,
276 [&](Type llvm1DVectorTy,
ValueRange operands) {
277 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
278 auto splatAttr = SplatElementsAttr::get(
279 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
280 {numElements.isScalable()}),
282 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
284 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
286 addAttrs.getAttrs());
287 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
295struct RsqrtOpLowering
298 using ConvertOpToLLVMPattern<
299 math::RsqrtOp,
true>::ConvertOpToLLVMPattern;
302 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter)
const override {
304 const auto &typeConverter = *this->getTypeConverter();
305 auto operandType = adaptor.getOperand().getType();
306 auto llvmOperandType = typeConverter.convertType(operandType);
307 if (!llvmOperandType)
310 auto loc = op.getLoc();
311 auto resultType = op.getResult().getType();
312 auto floatType = cast<FloatType>(
314 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
315 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
316 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
318 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
319 LLVM::ConstantOp one;
320 if (isa<VectorType>(llvmOperandType)) {
321 one = LLVM::ConstantOp::create(
322 rewriter, loc, llvmOperandType,
323 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
327 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
329 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
330 sqrtAttrs.getAttrs());
331 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
332 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
336 if (!isa<VectorType>(resultType))
340 op.getOperation(), adaptor.getOperands(), typeConverter,
341 [&](Type llvm1DVectorTy,
ValueRange operands) {
342 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
343 auto splatAttr = SplatElementsAttr::get(
344 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
345 {numElements.isScalable()}),
347 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
349 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
350 operands[0], sqrtAttrs.getAttrs());
351 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
353 divAttrs.getAttrs());
359struct IsNaNOpLowering
362 using ConvertOpToLLVMPattern<
363 math::IsNaNOp,
true>::ConvertOpToLLVMPattern;
366 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
367 ConversionPatternRewriter &rewriter)
const override {
368 const auto &typeConverter = *this->getTypeConverter();
370 typeConverter.convertType(adaptor.getOperand().getType());
371 auto resultType = typeConverter.convertType(op.getResult().getType());
372 if (!operandType || !resultType)
375 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
376 op, resultType, adaptor.getOperand(), llvm::fcNan);
381struct IsFiniteOpLowering
384 using ConvertOpToLLVMPattern<
385 math::IsFiniteOp,
true>::ConvertOpToLLVMPattern;
388 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
389 ConversionPatternRewriter &rewriter)
const override {
390 const auto &typeConverter = *this->getTypeConverter();
392 typeConverter.convertType(adaptor.getOperand().getType());
393 auto resultType = typeConverter.convertType(op.getResult().getType());
394 if (!operandType || !resultType)
397 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
398 op, resultType, adaptor.getOperand(), llvm::fcFinite);
403struct ConvertMathToLLVMPass
407 void runOnOperation()
override {
412 if (
failed(applyPartialConversion(getOperation(),
target,
422 if (approximateLog1p)
423 patterns.add<Log1pOpLowering>(converter, benefit);
435 CountLeadingZerosOpLowering,
436 CountTrailingZerosOpLowering,
461 >(converter, benefit);
473 void loadDependentDialects(
MLIRContext *context)
const final {
474 context->loadDialect<LLVM::LLVMDialect>();
479 void populateConvertToLLVMConversionPatterns(
480 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
481 RewritePatternSet &
patterns)
const final {
489 dialect->addInterfaces<MathToLLVMDialectInterface>();
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename math::SincosOp::Adaptor OpAdaptor
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override