21#include "llvm/ADT/FloatingPointMode.h"
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
32template <
typename SourceOp,
typename TargetOp>
35template <
typename SourceOp,
typename TargetOp,
bool FailOnUnsupportedFP = true>
36using ConvertFMFMathToLLVMPattern =
41 ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
43using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
44using CopySignOpLowering =
45 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
46using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
47using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
48using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
49using CtPopFOpLowering =
53using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
54using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
55using FloorOpLowering =
56 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
57using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
59using Log10OpLowering =
60 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
61using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
62using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
63using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
64using FPowIOpLowering =
65 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
66using RoundEvenOpLowering =
67 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
68using RoundOpLowering =
69 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
70using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
71using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
72using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
73using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
74using FTruncOpLowering =
75 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
76using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
77using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
78using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
79using ATan2OpLowering =
80 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
84template <
typename MathOp,
typename LLVMOp>
85struct IntOpWithFlagLowering
87 using ConvertOpToLLVMPattern<
88 MathOp,
true>::ConvertOpToLLVMPattern;
89 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
92 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
93 ConversionPatternRewriter &rewriter)
const override {
94 const auto &typeConverter = *this->getTypeConverter();
95 auto operandType = adaptor.getOperand().getType();
96 auto llvmOperandType = typeConverter.convertType(operandType);
100 auto loc = op.getLoc();
101 auto resultType = op.getResult().getType();
102 auto llvmResultType = typeConverter.convertType(resultType);
106 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
107 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
108 adaptor.getOperand(),
false);
112 if (!isa<VectorType>(llvmResultType))
116 op.getOperation(), adaptor.getOperands(), typeConverter,
117 [&](Type llvm1DVectorTy,
ValueRange operands) {
118 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
125using CountLeadingZerosOpLowering =
126 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
127using CountTrailingZerosOpLowering =
128 IntOpWithFlagLowering<math::CountTrailingZerosOp,
129 LLVM::CountTrailingZerosOp>;
130using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
133struct SincosOpLowering
137 math::SincosOp,
true>::ConvertOpToLLVMPattern;
141 ConversionPatternRewriter &rewriter)
const override {
143 mlir::Location loc = op.getLoc();
144 mlir::Type operandType = adaptor.getOperand().getType();
145 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
146 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
147 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
148 if (!llvmOperandType || !sinType || !cosType)
151 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
153 auto structType = LLVM::LLVMStructType::getLiteral(
154 rewriter.getContext(), {llvmOperandType, llvmOperandType});
156 auto sincosOp = LLVM::SincosOp::create(
157 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
159 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
160 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
162 rewriter.replaceOp(op, {sinValue, cosValue});
168struct ExpM1OpLowering
171 using ConvertOpToLLVMPattern<
172 math::ExpM1Op,
true>::ConvertOpToLLVMPattern;
175 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
176 ConversionPatternRewriter &rewriter)
const override {
177 const auto &typeConverter = *this->getTypeConverter();
178 auto operandType = adaptor.getOperand().getType();
179 auto llvmOperandType = typeConverter.convertType(operandType);
180 if (!llvmOperandType)
183 auto loc = op.getLoc();
184 auto resultType = op.getResult().getType();
185 auto floatType = cast<FloatType>(
187 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
188 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
189 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
191 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
192 LLVM::ConstantOp one;
194 one = LLVM::ConstantOp::create(
195 rewriter, loc, llvmOperandType,
196 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
200 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
202 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
203 expAttrs.getAttrs());
204 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
205 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
209 if (!isa<VectorType>(resultType))
210 return rewriter.notifyMatchFailure(op,
"expected vector result type");
213 op.getOperation(), adaptor.getOperands(), typeConverter,
214 [&](Type llvm1DVectorTy,
ValueRange operands) {
215 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
216 auto splatAttr = SplatElementsAttr::get(
217 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
218 {numElements.isScalable()}),
220 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
222 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
223 operands[0], expAttrs.getAttrs());
224 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
226 subAttrs.getAttrs());
233struct Log1pOpLowering
236 using ConvertOpToLLVMPattern<
237 math::Log1pOp,
true>::ConvertOpToLLVMPattern;
240 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter)
const override {
242 const auto &typeConverter = *this->getTypeConverter();
243 auto operandType = adaptor.getOperand().getType();
244 auto llvmOperandType = typeConverter.convertType(operandType);
245 if (!llvmOperandType)
246 return rewriter.notifyMatchFailure(op,
"unsupported operand type");
248 auto loc = op.getLoc();
249 auto resultType = op.getResult().getType();
250 auto floatType = cast<FloatType>(
252 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
253 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
254 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
256 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
257 LLVM::ConstantOp one =
258 isa<VectorType>(llvmOperandType)
259 ? LLVM::ConstantOp::create(
260 rewriter, loc, llvmOperandType,
261 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
263 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
266 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
268 addAttrs.getAttrs());
269 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
274 if (!isa<VectorType>(resultType))
275 return rewriter.notifyMatchFailure(op,
"expected vector result type");
278 op.getOperation(), adaptor.getOperands(), typeConverter,
279 [&](Type llvm1DVectorTy,
ValueRange operands) {
280 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
281 auto splatAttr = SplatElementsAttr::get(
282 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
283 {numElements.isScalable()}),
285 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
287 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
289 addAttrs.getAttrs());
290 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
298struct RsqrtOpLowering
301 using ConvertOpToLLVMPattern<
302 math::RsqrtOp,
true>::ConvertOpToLLVMPattern;
305 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter)
const override {
307 const auto &typeConverter = *this->getTypeConverter();
308 auto operandType = adaptor.getOperand().getType();
309 auto llvmOperandType = typeConverter.convertType(operandType);
310 if (!llvmOperandType)
313 auto loc = op.getLoc();
314 auto resultType = op.getResult().getType();
315 auto floatType = cast<FloatType>(
317 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
318 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
319 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
321 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
322 LLVM::ConstantOp one;
323 if (isa<VectorType>(llvmOperandType)) {
324 one = LLVM::ConstantOp::create(
325 rewriter, loc, llvmOperandType,
326 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
330 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
332 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
333 sqrtAttrs.getAttrs());
334 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
335 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
339 if (!isa<VectorType>(resultType))
343 op.getOperation(), adaptor.getOperands(), typeConverter,
344 [&](Type llvm1DVectorTy,
ValueRange operands) {
345 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
346 auto splatAttr = SplatElementsAttr::get(
347 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
348 {numElements.isScalable()}),
350 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
352 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
353 operands[0], sqrtAttrs.getAttrs());
354 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
356 divAttrs.getAttrs());
362struct IsNaNOpLowering
365 using ConvertOpToLLVMPattern<
366 math::IsNaNOp,
true>::ConvertOpToLLVMPattern;
369 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter)
const override {
371 const auto &typeConverter = *this->getTypeConverter();
373 typeConverter.convertType(adaptor.getOperand().getType());
374 auto resultType = typeConverter.convertType(op.getResult().getType());
375 if (!operandType || !resultType)
378 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
379 op, resultType, adaptor.getOperand(), llvm::fcNan);
384struct IsFiniteOpLowering
387 using ConvertOpToLLVMPattern<
388 math::IsFiniteOp,
true>::ConvertOpToLLVMPattern;
391 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter)
const override {
393 const auto &typeConverter = *this->getTypeConverter();
395 typeConverter.convertType(adaptor.getOperand().getType());
396 auto resultType = typeConverter.convertType(op.getResult().getType());
397 if (!operandType || !resultType)
400 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
401 op, resultType, adaptor.getOperand(), llvm::fcFinite);
406struct ConvertMathToLLVMPass
410 void runOnOperation()
override {
415 if (
failed(applyPartialConversion(getOperation(),
target,
416 std::move(patterns))))
425 if (approximateLog1p)
426 patterns.
add<Log1pOpLowering>(converter, benefit);
438 CountLeadingZerosOpLowering,
439 CountTrailingZerosOpLowering,
464 >(converter, benefit);
476 void loadDependentDialects(
MLIRContext *context)
const final {
477 context->loadDialect<LLVM::LLVMDialect>();
482 void populateConvertToLLVMConversionPatterns(
483 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
484 RewritePatternSet &patterns)
const final {
492 dialect->addInterfaces<MathToLLVMDialectInterface>();
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override