21 #include "llvm/ADT/FloatingPointMode.h"
24 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
32 template <
typename SourceOp,
typename TargetOp>
33 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
35 template <
typename SourceOp,
typename TargetOp>
36 using ConvertFMFMathToLLVMPattern =
39 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41 using CopySignOpLowering =
42 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45 using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
46 using CtPopFOpLowering =
48 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50 using FloorOpLowering =
51 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53 using Log10OpLowering =
54 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
57 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58 using FPowIOpLowering =
59 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60 using RoundEvenOpLowering =
61 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62 using RoundOpLowering =
63 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65 using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66 using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
68 using FTruncOpLowering =
69 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70 using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71 using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72 using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73 using ATan2OpLowering =
74 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
78 template <
typename MathOp,
typename LLVMOp>
81 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
84 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
86 const auto &typeConverter = *this->getTypeConverter();
87 auto operandType = adaptor.getOperand().getType();
88 auto llvmOperandType = typeConverter.convertType(operandType);
92 auto loc = op.getLoc();
93 auto resultType = op.getResult().getType();
94 auto llvmResultType = typeConverter.convertType(resultType);
98 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
100 adaptor.getOperand(),
false);
104 if (!isa<VectorType>(llvmResultType))
108 op.getOperation(), adaptor.getOperands(), typeConverter,
110 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
117 using CountLeadingZerosOpLowering =
118 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119 using CountTrailingZerosOpLowering =
120 IntOpWithFlagLowering<math::CountTrailingZerosOp,
121 LLVM::CountTrailingZerosOp>;
122 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
129 matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
133 mlir::Type operandType = adaptor.getOperand().getType();
137 if (!llvmOperandType || !sinType || !cosType)
140 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
142 auto structType = LLVM::LLVMStructType::getLiteral(
143 rewriter.
getContext(), {llvmOperandType, llvmOperandType});
145 auto sincosOp = rewriter.
create<LLVM::SincosOp>(
146 loc, structType, adaptor.getOperand(), attrs.getAttrs());
148 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
149 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
151 rewriter.
replaceOp(op, {sinValue, cosValue});
161 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
163 const auto &typeConverter = *this->getTypeConverter();
164 auto operandType = adaptor.getOperand().getType();
165 auto llvmOperandType = typeConverter.
convertType(operandType);
166 if (!llvmOperandType)
169 auto loc = op.getLoc();
170 auto resultType = op.getResult().getType();
171 auto floatType = cast<FloatType>(
174 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
175 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
177 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
178 LLVM::ConstantOp one;
180 one = LLVM::ConstantOp::create(
181 rewriter, loc, llvmOperandType,
186 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
188 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
189 expAttrs.getAttrs());
191 op, llvmOperandType,
ValueRange{exp, one}, subAttrs.getAttrs());
195 if (!isa<VectorType>(resultType))
199 op.getOperation(), adaptor.getOperands(), typeConverter,
201 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
202 auto splatAttr = SplatElementsAttr::get(
203 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
204 {numElements.isScalable()}),
206 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
208 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
209 operands[0], expAttrs.getAttrs());
210 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
212 subAttrs.getAttrs());
223 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
225 const auto &typeConverter = *this->getTypeConverter();
226 auto operandType = adaptor.getOperand().getType();
227 auto llvmOperandType = typeConverter.
convertType(operandType);
228 if (!llvmOperandType)
231 auto loc = op.getLoc();
232 auto resultType = op.getResult().getType();
233 auto floatType = cast<FloatType>(
236 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
237 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
239 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
240 LLVM::ConstantOp one =
241 isa<VectorType>(llvmOperandType)
242 ? LLVM::ConstantOp::create(
243 rewriter, loc, llvmOperandType,
246 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
249 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
251 addAttrs.getAttrs());
257 if (!isa<VectorType>(resultType))
261 op.getOperation(), adaptor.getOperands(), typeConverter,
263 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
264 auto splatAttr = SplatElementsAttr::get(
265 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
266 {numElements.isScalable()}),
268 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
270 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
272 addAttrs.getAttrs());
273 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
288 auto operandType = adaptor.getOperand().getType();
289 auto llvmOperandType = typeConverter.
convertType(operandType);
290 if (!llvmOperandType)
293 auto loc = op.getLoc();
294 auto resultType = op.getResult().getType();
295 auto floatType = cast<FloatType>(
298 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
299 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
301 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
302 LLVM::ConstantOp one;
303 if (isa<VectorType>(llvmOperandType)) {
304 one = LLVM::ConstantOp::create(
305 rewriter, loc, llvmOperandType,
310 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
312 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
313 sqrtAttrs.getAttrs());
315 op, llvmOperandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
319 if (!isa<VectorType>(resultType))
323 op.getOperation(), adaptor.getOperands(), typeConverter,
325 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
326 auto splatAttr = SplatElementsAttr::get(
327 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
328 {numElements.isScalable()}),
330 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
332 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
333 operands[0], sqrtAttrs.getAttrs());
334 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
336 divAttrs.getAttrs());
350 typeConverter.
convertType(adaptor.getOperand().getType());
351 auto resultType = typeConverter.
convertType(op.getResult().getType());
352 if (!operandType || !resultType)
356 op, resultType, adaptor.getOperand(), llvm::fcNan);
369 typeConverter.
convertType(adaptor.getOperand().getType());
370 auto resultType = typeConverter.
convertType(op.getResult().getType());
371 if (!operandType || !resultType)
375 op, resultType, adaptor.getOperand(), llvm::fcFinite);
380 struct ConvertMathToLLVMPass
381 :
public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
384 void runOnOperation()
override {
399 if (approximateLog1p)
400 patterns.add<Log1pOpLowering>(converter, benefit);
412 CountLeadingZerosOpLowering,
413 CountTrailingZerosOpLowering,
438 >(converter, benefit);
451 context->loadDialect<LLVM::LLVMDialect>();
466 dialect->addInterfaces<MathToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry ®istry)