21#include "llvm/ADT/FloatingPointMode.h"
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
// NOTE(review): the declaration introduced by this first template header is
// missing from this listing. `ConvertFastMath<SourceOp, TargetOp>`, used
// later in this file with the same two parameters, is presumably declared
// here — confirm against the full source.
32template <
typename SourceOp,
typename TargetOp>
/// Alias template used by the `*OpLowering` aliases below to lower a math
/// dialect `SourceOp` to an LLVM dialect `TargetOp`.
/// `FailOnUnsupportedFP` defaults to `true`. Presumably it controls whether
/// the pattern fails on unsupported floating-point element types — TODO
/// confirm. The right-hand side of the alias is not visible in this listing.
35template <
typename SourceOp,
typename TargetOp,
bool FailOnUnsupportedFP = true>
36using ConvertFMFMathToLLVMPattern =
/// One-to-one lowerings of math dialect floating-point ops to LLVM dialect
/// ops. Each one instantiates ConvertFMFMathToLLVMPattern<SourceOp, TargetOp>,
/// e.g. math::CeilOp -> LLVM::FCeilOp.
// NOTE(review): this listing is missing the `using ... =` head of the
// math::AbsFOp -> LLVM::FAbsOp alias and its trailing template argument
// (presumably the FailOnUnsupportedFP flag — confirm).
41 ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
43using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
44using CopySignOpLowering =
45 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
46using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
47using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
48using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
// NOTE(review): the right-hand side of CtPopFOpLowering is not visible in
// this listing.
49using CtPopFOpLowering =
53using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
54using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
55using FloorOpLowering =
56 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
// NOTE(review): the trailing template argument of FmaOpLowering is not
// visible in this listing (presumably the FailOnUnsupportedFP flag).
57using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
59using Log10OpLowering =
60 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
61using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
62using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
63using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
// Floating-point base raised to an integer power.
64using FPowIOpLowering =
65 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
66using RoundEvenOpLowering =
67 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
68using RoundOpLowering =
69 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
70using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
71using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
72using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
73using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
74using FTruncOpLowering =
75 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
76using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
77using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
78using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
79using ATan2OpLowering =
80 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
/// Lowers a single-operand math op (`MathOp`) to an LLVM dialect op
/// (`LLVMOp`) whose builder takes a trailing boolean flag. This pattern always
/// passes `false` for that flag. The flag is presumably the intrinsic's poison
/// flag (e.g. is_zero_poison for ctlz/cttz) — confirm against LangRef.
// NOTE(review): this listing is missing the struct's base-class line and
// several statements, e.g. the null checks after each convertType() call.
84template <
typename MathOp,
typename LLVMOp>
85struct IntOpWithFlagLowering
87 using ConvertOpToLLVMPattern<
88 MathOp,
true>::ConvertOpToLLVMPattern;
  // Shorthand for this pattern's own type.
89 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
92 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
93 ConversionPatternRewriter &rewriter)
const override {
94 const auto &typeConverter = *this->getTypeConverter();
95 auto operandType = adaptor.getOperand().getType();
96 auto llvmOperandType = typeConverter.convertType(operandType);
100 auto loc = op.getLoc();
101 auto resultType = op.getResult().getType();
102 auto llvmResultType = typeConverter.convertType(resultType);
    // Operand does not convert to an LLVM array type: replace the op with a
    // single LLVMOp on the converted result type, with the flag set to false.
106 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
107 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
108 adaptor.getOperand(),
false);
    // Remaining case: the operand converted to an LLVM array type. The
    // callback below emits one LLVMOp (flag false) per 1-D vector piece. The
    // helper it is passed to is not visible in this listing; presumably it is
    // LLVM::detail::handleMultidimensionalVectors.
    // NOTE(review): this guard tests the *converted* result type. On this
    // path llvmOperandType is an LLVMArrayType. For ops whose result has the
    // operand's type (ctlz/cttz/absi), llvmResultType converts the same way,
    // so `isa<VectorType>(llvmResultType)` looks always-false and the
    // unrolling below unreachable. ExpM1/Log1p/Rsqrt in this file test the
    // unconverted `resultType` instead. Consider
    // `isa<VectorType>(resultType)` and verify with a 2-D vector test.
112 if (!isa<VectorType>(llvmResultType))
116 op.getOperation(), adaptor.getOperands(), typeConverter,
117 [&](Type llvm1DVectorTy,
ValueRange operands) {
118 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
/// Integer math ops lowered through IntOpWithFlagLowering: count leading
/// zeros, count trailing zeros, and integer absolute value.
125using CountLeadingZerosOpLowering =
126 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
127using CountTrailingZerosOpLowering =
128 IntOpWithFlagLowering<math::CountTrailingZerosOp,
129 LLVM::CountTrailingZerosOp>;
130using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
// NOTE(review): the enclosing struct and the start of this signature are
// not visible in this listing. The body lowers math::SincosOp to a single
// LLVM::SincosOp and splits its struct result into the sin and cos values.
141 ConversionPatternRewriter &rewriter)
const override {
143 mlir::Location loc = op.getLoc();
144 mlir::Type operandType = adaptor.getOperand().getType();
145 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
146 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
147 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
    // All three types must convert. sinType/cosType are only checked here;
    // the result struct below is built from the operand type.
148 if (!llvmOperandType || !sinType || !cosType)
    // Attributes converted from the math op (by name, its fast-math flags),
    // attached to the LLVM op below.
151 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
    // Result is a literal struct of two operand-typed fields:
    // field 0 = sin, field 1 = cos (see the extractions below).
    // NOTE(review): unlike ExpM1/Log1p/Rsqrt, there is no multi-dimensional
    // vector unrolling here. If the operand converts to an LLVM array type,
    // the struct fields are arrays — confirm LLVM::SincosOp accepts that.
153 auto structType = LLVM::LLVMStructType::getLiteral(
154 rewriter.getContext(), {llvmOperandType, llvmOperandType});
156 auto sincosOp = LLVM::SincosOp::create(
157 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
159 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
160 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
    // The op's two results are replaced in order: sin, then cos.
162 rewriter.replaceOp(op, {sinValue, cosValue});
/// Lowers math::ExpM1Op by expansion: LLVM::ExpOp on the operand, then
/// LLVM::FSubOp of a constant 1.0, i.e. exp(x) - 1.
// NOTE(review): subtracting 1 from exp(x) loses the extra accuracy expm1
// normally provides for x near 0 (cancellation). Confirm this approximation
// is intended.
168struct ExpM1OpLowering
171 using ConvertOpToLLVMPattern<
172 math::ExpM1Op,
true>::ConvertOpToLLVMPattern;
175 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
176 ConversionPatternRewriter &rewriter)
const override {
177 const auto &typeConverter = *this->getTypeConverter();
178 auto operandType = adaptor.getOperand().getType();
179 auto llvmOperandType = typeConverter.convertType(operandType);
    // Bail out if the operand type cannot be converted.
180 if (!llvmOperandType)
183 auto loc = op.getLoc();
184 auto resultType = op.getResult().getType();
    // Scalar float type used for the constant 1.0. Its argument is not
    // visible here; presumably getElementTypeOrSelf(resultType) — confirm.
185 auto floatType = cast<FloatType>(
187 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
    // Attribute converters (by name, fast-math flags) for the two emitted ops.
188 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
189 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
    // Operand not converted to an LLVM array: emit the expansion directly,
    // splatting 1.0 if the converted type is a compatible vector type.
191 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
192 LLVM::ConstantOp one;
193 if (LLVM::isCompatibleVectorType(llvmOperandType)) {
194 one = LLVM::ConstantOp::create(
195 rewriter, loc, llvmOperandType,
196 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
200 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
202 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
203 expAttrs.getAttrs());
204 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
205 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
    // Array-typed operand: only multi-dimensional vector results are handled.
209 if (!isa<VectorType>(resultType))
210 return rewriter.notifyMatchFailure(op,
"expected vector result type");
    // Per 1-D vector piece: a splat of 1.0 with the piece's (possibly
    // scalable) element count, then exp and fsub as in the scalar path.
213 op.getOperation(), adaptor.getOperands(), typeConverter,
214 [&](Type llvm1DVectorTy,
ValueRange operands) {
215 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
216 auto splatAttr = SplatElementsAttr::get(
217 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
218 {numElements.isScalable()}),
220 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
222 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
223 operands[0], expAttrs.getAttrs());
224 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
226 subAttrs.getAttrs());
/// Lowers math::Log1pOp by expansion: an LLVM::FAddOp involving a constant 1.0
/// (operands not visible here; presumably {one, x}), followed by LLVM::LogOp,
/// i.e. log(1 + x).
// NOTE(review): log(1 + x) loses accuracy for x near 0 (cancellation in
// 1 + x). populateMathToLLVMConversionPatterns only registers this pattern
// when `approximateLog1p` is set, which is consistent with it being an
// approximation.
233struct Log1pOpLowering
236 using ConvertOpToLLVMPattern<
237 math::Log1pOp,
true>::ConvertOpToLLVMPattern;
240 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter)
const override {
242 const auto &typeConverter = *this->getTypeConverter();
243 auto operandType = adaptor.getOperand().getType();
244 auto llvmOperandType = typeConverter.convertType(operandType);
245 if (!llvmOperandType)
246 return rewriter.notifyMatchFailure(op,
"unsupported operand type");
248 auto loc = op.getLoc();
249 auto resultType = op.getResult().getType();
    // Scalar float type for the constant 1.0 (argument not visible here;
    // presumably the element type of resultType — confirm).
250 auto floatType = cast<FloatType>(
252 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
253 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
254 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
    // Operand not converted to an LLVM array: splat 1.0 for vector types,
    // scalar constant otherwise, then add and log.
    // NOTE(review): ExpM1OpLowering uses LLVM::isCompatibleVectorType for the
    // same decision; this pattern and Rsqrt use isa<VectorType>.
256 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
257 LLVM::ConstantOp one =
258 isa<VectorType>(llvmOperandType)
259 ? LLVM::ConstantOp::create(
260 rewriter, loc, llvmOperandType,
261 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
263 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
266 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
268 addAttrs.getAttrs());
269 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
    // Array-typed operand: only multi-dimensional vector results are handled.
274 if (!isa<VectorType>(resultType))
275 return rewriter.notifyMatchFailure(op,
"expected vector result type");
    // Per 1-D vector piece: splat 1.0 sized to the piece (including
    // scalability), then add and log as in the non-array path.
278 op.getOperation(), adaptor.getOperands(), typeConverter,
279 [&](Type llvm1DVectorTy,
ValueRange operands) {
280 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
281 auto splatAttr = SplatElementsAttr::get(
282 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
283 {numElements.isScalable()}),
285 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
287 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
289 addAttrs.getAttrs());
290 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
/// Lowers math::RsqrtOp by expansion: LLVM::SqrtOp on the operand, then
/// LLVM::FDivOp of a constant 1.0 by that result, i.e. 1 / sqrt(x).
298struct RsqrtOpLowering
301 using ConvertOpToLLVMPattern<
302 math::RsqrtOp,
true>::ConvertOpToLLVMPattern;
305 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter)
const override {
307 const auto &typeConverter = *this->getTypeConverter();
308 auto operandType = adaptor.getOperand().getType();
309 auto llvmOperandType = typeConverter.convertType(operandType);
    // Bail out if the operand type cannot be converted.
310 if (!llvmOperandType)
313 auto loc = op.getLoc();
314 auto resultType = op.getResult().getType();
    // Scalar float type for the constant 1.0 (argument not visible here;
    // presumably the element type of resultType — confirm).
315 auto floatType = cast<FloatType>(
317 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
318 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
319 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
    // Operand not converted to an LLVM array: build 1.0 (splat if vector),
    // take sqrt, and divide 1.0 by it.
321 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
322 LLVM::ConstantOp one;
323 if (isa<VectorType>(llvmOperandType)) {
324 one = LLVM::ConstantOp::create(
325 rewriter, loc, llvmOperandType,
326 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
330 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
332 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
333 sqrtAttrs.getAttrs());
    // Operand order {one, sqrt} makes this 1.0 / sqrt(x).
334 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
335 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
    // Array-typed operand: only multi-dimensional vector results are handled.
339 if (!isa<VectorType>(resultType))
    // Per 1-D vector piece: splat 1.0 sized to the piece (including
    // scalability), then sqrt and fdiv as in the non-array path.
343 op.getOperation(), adaptor.getOperands(), typeConverter,
344 [&](Type llvm1DVectorTy,
ValueRange operands) {
345 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
346 auto splatAttr = SplatElementsAttr::get(
347 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
348 {numElements.isScalable()}),
350 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
352 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
353 operands[0], sqrtAttrs.getAttrs());
354 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
356 divAttrs.getAttrs());
/// Lowers math::IsNaNOp to LLVM::IsFPClass, testing the operand against the
/// llvm::fcNan class mask.
362struct IsNaNOpLowering
365 using ConvertOpToLLVMPattern<
366 math::IsNaNOp,
true>::ConvertOpToLLVMPattern;
369 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter)
const override {
371 const auto &typeConverter = *this->getTypeConverter();
    // Converted operand type. The `operandType` declaration head is on a
    // line not visible in this listing.
373 typeConverter.convertType(adaptor.getOperand().getType());
374 auto resultType = typeConverter.convertType(op.getResult().getType());
    // Both the operand and result types must convert.
375 if (!operandType || !resultType)
378 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
379 op, resultType, adaptor.getOperand(), llvm::fcNan);
/// Lowers math::IsFiniteOp to LLVM::IsFPClass, testing the operand against
/// the llvm::fcFinite class mask.
384struct IsFiniteOpLowering
387 using ConvertOpToLLVMPattern<
388 math::IsFiniteOp,
true>::ConvertOpToLLVMPattern;
391 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter)
const override {
393 const auto &typeConverter = *this->getTypeConverter();
    // Converted operand type. The `operandType` declaration head is on a
    // line not visible in this listing.
395 typeConverter.convertType(adaptor.getOperand().getType());
396 auto resultType = typeConverter.convertType(op.getResult().getType());
    // Both the operand and result types must convert.
397 if (!operandType || !resultType)
400 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
401 op, resultType, adaptor.getOperand(), llvm::fcFinite);
/// Pass that converts math dialect ops to the LLVM dialect. Its base class is
/// generated from the pass definitions (GEN_PASS_DEF_CONVERTMATHTOLLVMPASS).
406struct ConvertMathToLLVMPass
407 :
public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
  // Runs a partial conversion over the pass's current operation. The setup of
  // `target` and the pattern set is on lines not visible in this listing.
  // Presumably a failed conversion signals pass failure — confirm.
410 void runOnOperation()
override {
415 if (
failed(applyPartialConversion(getOperation(),
target,
// Log1pOpLowering expands log1p(x) as log(1 + x), so it is only registered
// when the caller opts in via `approximateLog1p`.
425 if (approximateLog1p)
426 patterns.add<Log1pOpLowering>(converter, benefit);
// NOTE(review): most entries of the following pattern list are not visible
// in this listing. All listed patterns share the same converter and benefit.
438 CountLeadingZerosOpLowering,
439 CountTrailingZerosOpLowering,
464 >(converter, benefit);
// Members presumably of MathToLLVMDialectInterface (its struct header is not
// visible in this listing). The converted IR uses LLVM dialect ops, so that
// dialect is loaded up front.
476 void loadDependentDialects(
MLIRContext *context)
const final {
477 context->loadDialect<LLVM::LLVMDialect>();
// Hook through which a generic convert-to-LLVM driver obtains this dialect's
// conversion patterns (body not visible in this listing).
482 void populateConvertToLLVMConversionPatterns(
483 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
484 RewritePatternSet &
patterns)
const final {
// Attaches MathToLLVMDialectInterface to the dialect. Presumably this runs
// inside the registry extension added by registerConvertMathToLLVMInterface —
// confirm.
492 dialect->addInterfaces<MathToLLVMDialectInterface>();
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename math::SincosOp::Adaptor OpAdaptor
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override