22 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
23 #include "mlir/Conversion/Passes.h.inc"
30 template <
typename SourceOp,
typename TargetOp>
31 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
33 template <
typename SourceOp,
typename TargetOp>
34 using ConvertFMFMathToLLVMPattern =
37 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
38 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
39 using CopySignOpLowering =
40 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
41 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
42 using CtPopFOpLowering =
44 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
45 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
46 using FloorOpLowering =
47 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
48 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
49 using Log10OpLowering =
50 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
51 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
52 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
53 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
54 using FPowIOpLowering =
55 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
56 using RoundEvenOpLowering =
57 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
58 using RoundOpLowering =
59 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
60 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
61 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
62 using FTruncOpLowering =
63 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
66 template <
typename MathOp,
typename LLVMOp>
69 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
72 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
74 auto operandType = adaptor.getOperand().getType();
82 if (!isa<LLVM::LLVMArrayType>(operandType)) {
88 auto vectorType = dyn_cast<VectorType>(resultType);
93 op.getOperation(), adaptor.
getOperands(), *this->getTypeConverter(),
95 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
102 using CountLeadingZerosOpLowering =
103 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
104 using CountTrailingZerosOpLowering =
105 IntOpWithFlagLowering<math::CountTrailingZerosOp,
106 LLVM::CountTrailingZerosOp>;
107 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
114 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
116 auto operandType = adaptor.getOperand().getType();
125 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
126 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
128 if (!isa<LLVM::LLVMArrayType>(operandType)) {
129 LLVM::ConstantOp one;
131 one = rewriter.
create<LLVM::ConstantOp>(
135 one = rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
137 auto exp = rewriter.
create<LLVM::ExpOp>(loc, adaptor.getOperand(),
138 expAttrs.getAttrs());
140 op, operandType,
ValueRange{exp, one}, subAttrs.getAttrs());
144 auto vectorType = dyn_cast<VectorType>(resultType);
149 op.getOperation(), adaptor.
getOperands(), *getTypeConverter(),
157 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
158 auto exp = rewriter.
create<LLVM::ExpOp>(
159 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
160 return rewriter.
create<LLVM::FSubOp>(
161 loc, llvm1DVectorTy,
ValueRange{exp, one}, subAttrs.getAttrs());
172 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
174 auto operandType = adaptor.getOperand().getType();
183 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
184 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
186 if (!isa<LLVM::LLVMArrayType>(operandType)) {
187 LLVM::ConstantOp one =
189 ? rewriter.
create<LLVM::ConstantOp>(
193 : rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
195 auto add = rewriter.
create<LLVM::FAddOp>(
196 loc, operandType,
ValueRange{one, adaptor.getOperand()},
197 addAttrs.getAttrs());
199 logAttrs.getAttrs());
203 auto vectorType = dyn_cast<VectorType>(resultType);
208 op.getOperation(), adaptor.
getOperands(), *getTypeConverter(),
216 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
217 auto add = rewriter.
create<LLVM::FAddOp>(loc, llvm1DVectorTy,
219 addAttrs.getAttrs());
220 return rewriter.
create<LLVM::LogOp>(
221 loc, llvm1DVectorTy,
ValueRange{add}, logAttrs.getAttrs());
232 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
234 auto operandType = adaptor.getOperand().getType();
243 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
244 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
246 if (!isa<LLVM::LLVMArrayType>(operandType)) {
247 LLVM::ConstantOp one;
249 one = rewriter.
create<LLVM::ConstantOp>(
253 one = rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
255 auto sqrt = rewriter.
create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
256 sqrtAttrs.getAttrs());
258 op, operandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
262 auto vectorType = dyn_cast<VectorType>(resultType);
267 op.getOperation(), adaptor.
getOperands(), *getTypeConverter(),
275 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
276 auto sqrt = rewriter.
create<LLVM::SqrtOp>(
277 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
278 return rewriter.
create<LLVM::FDivOp>(
279 loc, llvm1DVectorTy,
ValueRange{one, sqrt}, divAttrs.getAttrs());
285 struct ConvertMathToLLVMPass
286 :
public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
289 void runOnOperation()
override {
295 std::move(patterns))))
303 bool approximateLog1p) {
304 if (approximateLog1p)
305 patterns.
add<Log1pOpLowering>(converter);
313 CountLeadingZerosOpLowering,
314 CountTrailingZerosOpLowering,
344 void loadDependentDialects(
MLIRContext *context)
const final {
345 context->loadDialect<LLVM::LLVMDialect>();
350 void populateConvertToLLVMConversionPatterns(
360 dialect->addInterfaces<MathToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator on the underlying Value's.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
Type getType() const
Return the type of this value.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
This class represents an efficient way to signal success or failure.