21#include "llvm/ADT/FloatingPointMode.h"
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
// Template header of a declaration parameterized on a source op and a target
// op. NOTE(review): the declaration that follows this header is missing from
// this excerpt.
32template <
typename SourceOp,
typename TargetOp>
// Alias template used by the lowering aliases in this file. It takes the
// source op, the target op and a FailOnUnsupportedFP flag that defaults to
// true. NOTE(review): the right-hand side of the alias is missing from this
// excerpt.
35template <
typename SourceOp,
typename TargetOp,
bool FailOnUnsupportedFP = true>
36using ConvertFMFMathToLLVMPattern =
// Conversion pattern layered on VectorConvertToLLVMPattern that additionally
// inspects the source op's rounding-mode attribute. HasRoundingMode states
// whether this instantiation expects the attribute to be present.
46template <
typename SourceOp,
typename TargetOp,
bool HasRoundingMode,
47 template <
typename,
typename>
typename AttrConvert =
49 bool FailOnUnsupportedFP =
true>
50struct ConstrainedVectorConvertToLLVMPattern
52 FailOnUnsupportedFP> {
// Inherit the base pattern's constructors.
53 using VectorConvertToLLVMPattern<
54 SourceOp, TargetOp, AttrConvert,
55 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
58 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
59 ConversionPatternRewriter &rewriter)
const override {
// Compare the expectation with whether getRoundingModeAttr() yields a
// non-null attribute.
// NOTE(review): the statement guarded by this `if` is missing from this
// excerpt; presumably the pattern bails out on a mismatch. Confirm against
// the full file.
60 if (HasRoundingMode !=
static_cast<bool>(op.getRoundingModeAttr()))
// Otherwise delegate to the inherited VectorConvertToLLVMPattern rewrite.
62 return VectorConvertToLLVMPattern<
63 SourceOp, TargetOp, AttrConvert,
64 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
// Lowering aliases: each one binds a math dialect op to the LLVM dialect op
// it is converted to, by instantiating ConvertFMFMathToLLVMPattern (or
// ConstrainedVectorConvertToLLVMPattern for math::FmaOp).
// NOTE(review): several alias names and trailing template arguments are
// missing from this excerpt.
69 ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
71using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
72using CopySignOpLowering =
73 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
74using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
75using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
76using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
77using CtPopFOpLowering =
81using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
82using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
83using FloorOpLowering =
84 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
// math::FmaOp lowered to LLVM::FMAOp through the rounding-mode-aware pattern.
// NOTE(review): the HasRoundingMode argument of this instantiation is not
// visible here.
86 ConstrainedVectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp,
// math::FmaOp lowered to LLVM::ConstrainedFMAIntr, instantiated with
// HasRoundingMode = true.
90using ConstrainedFmaOpLowering = ConstrainedVectorConvertToLLVMPattern<
91 math::FmaOp, LLVM::ConstrainedFMAIntr,
true,
93using Log10OpLowering =
94 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
95using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
96using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
97using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
98using FPowIOpLowering =
99 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
100using RoundEvenOpLowering =
101 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
102using RoundOpLowering =
103 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
104using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
105using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
106using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
107using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
108using FTruncOpLowering =
109 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
110using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
111using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
112using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
113using ATan2OpLowering =
114 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
// Generic lowering of a single-operand math op (MathOp) to an LLVM dialect op
// (LLVMOp) that takes one extra trailing boolean argument; the path visible
// here passes `false`.
// NOTE(review): what the flag means depends on LLVMOp and is not shown in
// this excerpt.
118template <
typename MathOp,
typename LLVMOp>
119struct IntOpWithFlagLowering
// Inherit the base pattern's constructors.
121 using ConvertOpToLLVMPattern<
122 MathOp,
true>::ConvertOpToLLVMPattern;
123 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
126 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
127 ConversionPatternRewriter &rewriter)
const override {
128 const auto &typeConverter = *this->getTypeConverter();
129 auto operandType = adaptor.getOperand().getType();
130 auto llvmOperandType = typeConverter.convertType(operandType);
// A null converted type means the operand type could not be converted.
// NOTE(review): the statement guarded by this check is missing from this
// excerpt.
131 if (!llvmOperandType)
134 auto loc = op.getLoc();
135 auto resultType = op.getResult().getType();
136 auto llvmResultType = typeConverter.convertType(resultType);
// Operand did not convert to an LLVM array type: replace the op directly
// with a single LLVMOp.
140 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
141 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
142 adaptor.getOperand(),
false);
// Array-typed operand: the original result must be a VectorType.
// NOTE(review): the guarded statement and the name of the helper that takes
// the arguments below are missing from this excerpt; presumably the helper is
// LLVM::detail::handleMultidimensionalVectors. Confirm.
146 if (!isa<VectorType>(resultType))
150 op.getOperation(), adaptor.getOperands(), typeConverter,
// Callback that builds one LLVMOp for a 1-D vector type from operands[0].
151 [&](Type llvm1DVectorTy,
ValueRange operands) {
152 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
// Instantiations of IntOpWithFlagLowering pairing integer math ops with the
// LLVM dialect ops they lower to.
159using CountLeadingZerosOpLowering =
160 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
161using CountTrailingZerosOpLowering =
162 IntOpWithFlagLowering<math::CountTrailingZerosOp,
163 LLVM::CountTrailingZerosOp>;
164using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
// Lowers math::SincosOp to a single LLVM::SincosOp that returns a
// two-element literal struct, then extracts the two fields as the sin and cos
// results.
// NOTE(review): the base-class clause, part of the matchAndRewrite signature
// and the typeConverter declaration are missing from this excerpt.
167struct SincosOpLowering
171 math::SincosOp,
true>::ConvertOpToLLVMPattern;
175 ConversionPatternRewriter &rewriter)
const override {
177 mlir::Location loc = op.getLoc();
178 mlir::Type operandType = adaptor.getOperand().getType();
179 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
180 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
181 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
// All three types must convert.
// NOTE(review): the statement guarded by this check is missing from this
// excerpt.
182 if (!llvmOperandType || !sinType || !cosType)
// Fast-math attributes converted from the source op for the LLVM op.
185 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
// Both struct fields use the converted operand type.
187 auto structType = LLVM::LLVMStructType::getLiteral(
188 rewriter.getContext(), {llvmOperandType, llvmOperandType});
190 auto sincosOp = LLVM::SincosOp::create(
191 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
// Field 0 is used as the sin result, field 1 as the cos result.
193 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
194 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
196 rewriter.replaceOp(op, {sinValue, cosValue});
// Lowers math::ExpM1Op by expanding it into LLVM::ExpOp followed by an
// LLVM::FSubOp against a constant 1.0, i.e. exp(x) - 1.
202struct ExpM1OpLowering
// Inherit the base pattern's constructors.
205 using ConvertOpToLLVMPattern<
206 math::ExpM1Op,
true>::ConvertOpToLLVMPattern;
209 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
210 ConversionPatternRewriter &rewriter)
const override {
211 const auto &typeConverter = *this->getTypeConverter();
212 auto operandType = adaptor.getOperand().getType();
213 auto llvmOperandType = typeConverter.convertType(operandType);
// NOTE(review): the statement guarded by this null check is missing from this
// excerpt.
214 if (!llvmOperandType)
217 auto loc = op.getLoc();
218 auto resultType = op.getResult().getType();
// NOTE(review): the argument of this cast is missing from this excerpt.
219 auto floatType = cast<FloatType>(
221 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
// The source op's fast-math flags are converted separately for each of the
// two LLVM ops that are emitted.
222 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
223 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
// Operand did not convert to an LLVM array type: emit the constant, exp and
// fsub directly.
225 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
226 LLVM::ConstantOp one;
// The constant is either a splat over the shaped operand type or a plain
// float constant.
// NOTE(review): the condition selecting between them is missing from this
// excerpt.
228 one = LLVM::ConstantOp::create(
229 rewriter, loc, llvmOperandType,
230 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
234 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
236 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
237 expAttrs.getAttrs());
// Replace the op with fsub(exp, one).
238 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
239 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
// Array-typed operand: only vector results are accepted.
243 if (!isa<VectorType>(resultType))
244 return rewriter.notifyMatchFailure(op,
"expected vector result type");
// NOTE(review): the name of the helper taking the arguments below is missing
// from this excerpt; presumably LLVM::detail::handleMultidimensionalVectors.
// Confirm.
247 op.getOperation(), adaptor.getOperands(), typeConverter,
// Callback emitting constant, exp and fsub for one 1-D vector type.
248 [&](Type llvm1DVectorTy,
ValueRange operands) {
249 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
// Build a splat attribute on a 1-D vector type that carries over the element
// count and the scalable flag of llvm1DVectorTy.
250 auto splatAttr = SplatElementsAttr::get(
251 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
252 {numElements.isScalable()}),
254 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
256 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
257 operands[0], expAttrs.getAttrs());
// NOTE(review): the fsub operand list is missing from this excerpt.
258 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
260 subAttrs.getAttrs());
// Lowers math::Log1pOp by expanding it into a constant 1.0, an LLVM::FAddOp
// and an LLVM::LogOp.
// NOTE(review): the operand lists of the add and log ops are missing from
// this excerpt; presumably the expansion is log(1 + x). Confirm.
267struct Log1pOpLowering
// Inherit the base pattern's constructors.
270 using ConvertOpToLLVMPattern<
271 math::Log1pOp,
true>::ConvertOpToLLVMPattern;
274 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter)
const override {
276 const auto &typeConverter = *this->getTypeConverter();
277 auto operandType = adaptor.getOperand().getType();
278 auto llvmOperandType = typeConverter.convertType(operandType);
// Bail out when the operand type cannot be converted.
279 if (!llvmOperandType)
280 return rewriter.notifyMatchFailure(op,
"unsupported operand type");
282 auto loc = op.getLoc();
283 auto resultType = op.getResult().getType();
// NOTE(review): the argument of this cast is missing from this excerpt.
284 auto floatType = cast<FloatType>(
286 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
// The source op's fast-math flags are converted separately for each of the
// two LLVM ops that are emitted.
287 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
288 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
// Operand did not convert to an LLVM array type: emit the ops directly.
290 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
// Vector operands get a splat constant; other operands get a plain constant.
291 LLVM::ConstantOp one =
292 isa<VectorType>(llvmOperandType)
293 ? LLVM::ConstantOp::create(
294 rewriter, loc, llvmOperandType,
295 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
297 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
300 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
302 addAttrs.getAttrs());
303 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
// Array-typed operand: only vector results are accepted.
308 if (!isa<VectorType>(resultType))
309 return rewriter.notifyMatchFailure(op,
"expected vector result type");
// NOTE(review): the name of the helper taking the arguments below is missing
// from this excerpt; presumably LLVM::detail::handleMultidimensionalVectors.
// Confirm.
312 op.getOperation(), adaptor.getOperands(), typeConverter,
// Callback emitting constant, fadd and log for one 1-D vector type.
313 [&](Type llvm1DVectorTy,
ValueRange operands) {
314 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
// Build a splat attribute on a 1-D vector type that carries over the element
// count and the scalable flag of llvm1DVectorTy.
315 auto splatAttr = SplatElementsAttr::get(
316 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
317 {numElements.isScalable()}),
319 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
321 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
323 addAttrs.getAttrs());
324 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
// Lowers math::RsqrtOp by expanding it into LLVM::SqrtOp followed by an
// LLVM::FDivOp whose numerator is a constant 1.0, i.e. 1 / sqrt(x).
332struct RsqrtOpLowering
// Inherit the base pattern's constructors.
335 using ConvertOpToLLVMPattern<
336 math::RsqrtOp,
true>::ConvertOpToLLVMPattern;
339 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter)
const override {
341 const auto &typeConverter = *this->getTypeConverter();
342 auto operandType = adaptor.getOperand().getType();
343 auto llvmOperandType = typeConverter.convertType(operandType);
// NOTE(review): the statement guarded by this null check is missing from this
// excerpt.
344 if (!llvmOperandType)
347 auto loc = op.getLoc();
348 auto resultType = op.getResult().getType();
// NOTE(review): the argument of this cast is missing from this excerpt.
349 auto floatType = cast<FloatType>(
351 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
// The source op's fast-math flags are converted separately for each of the
// two LLVM ops that are emitted.
352 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
353 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
// Operand did not convert to an LLVM array type: emit the ops directly.
355 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
356 LLVM::ConstantOp one;
// Vector operands get a splat constant; the plain-constant form below is
// presumably the else branch (its `else` line is missing from this excerpt).
357 if (isa<VectorType>(llvmOperandType)) {
358 one = LLVM::ConstantOp::create(
359 rewriter, loc, llvmOperandType,
360 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
364 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
366 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
367 sqrtAttrs.getAttrs());
// Replace the op with fdiv(one, sqrt).
368 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
369 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
// Array-typed operand: the original result must be a VectorType.
// NOTE(review): the guarded statement and the name of the helper taking the
// arguments below are missing from this excerpt.
373 if (!isa<VectorType>(resultType))
377 op.getOperation(), adaptor.getOperands(), typeConverter,
// Callback emitting constant, sqrt and fdiv for one 1-D vector type.
378 [&](Type llvm1DVectorTy,
ValueRange operands) {
379 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
// Build a splat attribute on a 1-D vector type that carries over the element
// count and the scalable flag of llvm1DVectorTy.
380 auto splatAttr = SplatElementsAttr::get(
381 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
382 {numElements.isScalable()}),
384 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
386 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
387 operands[0], sqrtAttrs.getAttrs());
// NOTE(review): the fdiv operand list is missing from this excerpt.
388 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
390 divAttrs.getAttrs());
// Lowers math::IsNaNOp to LLVM::IsFPClass with the llvm::fcNan class mask.
396struct IsNaNOpLowering
// Inherit the base pattern's constructors.
399 using ConvertOpToLLVMPattern<
400 math::IsNaNOp,
true>::ConvertOpToLLVMPattern;
403 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter)
const override {
405 const auto &typeConverter = *this->getTypeConverter();
// NOTE(review): the declaration receiving this converted operand type
// (operandType, per the check below) is missing from this excerpt.
407 typeConverter.convertType(adaptor.getOperand().getType());
408 auto resultType = typeConverter.convertType(op.getResult().getType());
// Both types must convert.
// NOTE(review): the statement guarded by this check is missing from this
// excerpt.
409 if (!operandType || !resultType)
412 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
413 op, resultType, adaptor.getOperand(), llvm::fcNan);
// Lowers math::IsFiniteOp to LLVM::IsFPClass with the llvm::fcFinite class
// mask.
418struct IsFiniteOpLowering
// Inherit the base pattern's constructors.
421 using ConvertOpToLLVMPattern<
422 math::IsFiniteOp,
true>::ConvertOpToLLVMPattern;
425 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
426 ConversionPatternRewriter &rewriter)
const override {
427 const auto &typeConverter = *this->getTypeConverter();
// NOTE(review): the declaration receiving this converted operand type
// (operandType, per the check below) is missing from this excerpt.
429 typeConverter.convertType(adaptor.getOperand().getType());
430 auto resultType = typeConverter.convertType(op.getResult().getType());
// Both types must convert.
// NOTE(review): the statement guarded by this check is missing from this
// excerpt.
431 if (!operandType || !resultType)
434 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
435 op, resultType, adaptor.getOperand(), llvm::fcFinite);
// Pass whose runOnOperation applies a partial dialect conversion to the
// pass's operation using `target` and `patterns`.
// NOTE(review): the base class, the setup of `target`/`patterns` and the
// handling of a failed conversion are missing from this excerpt.
440struct ConvertMathToLLVMPass
444 void runOnOperation()
override {
449 if (
failed(applyPartialConversion(getOperation(),
target,
450 std::move(patterns))))
// Pattern registration fragment. NOTE(review): the enclosing function
// signature and most of the pattern list are missing from this excerpt.
// Log1pOpLowering is only registered when approximateLog1p is set.
459 if (approximateLog1p)
460 patterns.
add<Log1pOpLowering>(converter, benefit);
// Remaining lowerings are added in one call with the same converter and
// benefit.
472 CountLeadingZerosOpLowering,
473 CountTrailingZerosOpLowering,
481 ConstrainedFmaOpLowering,
499 >(converter, benefit);
// Implementation of ConvertToLLVMPatternInterface that is attached to a
// dialect (see the addInterfaces call at the end of this block).
509struct MathToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
510 MathToLLVMDialectInterface(
Dialect *dialect)
511 : ConvertToLLVMPatternInterface(dialect) {}
// Loads the LLVM dialect into the given context.
513 void loadDependentDialects(MLIRContext *context)
const final {
514 context->loadDialect<LLVM::LLVMDialect>();
// Hook for contributing conversion patterns.
// NOTE(review): the body of this method is missing from this excerpt.
519 void populateConvertToLLVMConversionPatterns(
520 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
521 RewritePatternSet &patterns)
const final {
// Attaches the interface to `dialect`.
// NOTE(review): this statement belongs to a separate registration function
// whose surrounding lines are missing from this excerpt.
529 dialect->addInterfaces<MathToLLVMDialectInterface>();
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
const LLVMTypeConverter * getTypeConverter() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override