21 #define GEN_PASS_DEF_CONVERTMATHTOLLVM
22 #include "mlir/Conversion/Passes.h.inc"
29 template <
typename SourceOp,
typename TargetOp>
30 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
32 template <
typename SourceOp,
typename TargetOp>
33 using ConvertFMFMathToLLVMPattern =
36 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
37 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
38 using CopySignOpLowering =
39 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
40 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
41 using CtPopFOpLowering =
43 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
44 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
45 using FloorOpLowering =
46 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
47 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
48 using Log10OpLowering =
49 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
50 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
51 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
52 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
53 using FPowIOpLowering =
54 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
55 using RoundEvenOpLowering =
56 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
57 using RoundOpLowering =
58 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
59 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
60 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
61 using FTruncOpLowering =
62 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
65 template <
typename MathOp,
typename LLVMOp>
68 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
71 matchAndRewrite(MathOp op,
typename MathOp::Adaptor adaptor,
73 auto operandType = adaptor.getOperand().getType();
78 auto loc = op.getLoc();
79 auto resultType = op.getResult().getType();
82 if (!operandType.template isa<LLVM::LLVMArrayType>()) {
83 LLVM::ConstantOp zero = rewriter.
create<LLVM::ConstantOp>(loc, boolZero);
89 auto vectorType = resultType.template dyn_cast<VectorType>();
94 op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
96 LLVM::ConstantOp zero =
97 rewriter.create<LLVM::ConstantOp>(loc, boolZero);
98 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
105 using CountLeadingZerosOpLowering =
106 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
107 using CountTrailingZerosOpLowering =
108 IntOpWithFlagLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
109 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
116 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
118 auto operandType = adaptor.getOperand().getType();
123 auto loc = op.getLoc();
124 auto resultType = op.getResult().getType();
127 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
128 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
130 if (!operandType.isa<LLVM::LLVMArrayType>()) {
131 LLVM::ConstantOp one;
133 one = rewriter.
create<LLVM::ConstantOp>(
137 one = rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
139 auto exp = rewriter.
create<LLVM::ExpOp>(loc, adaptor.getOperand(),
140 expAttrs.getAttrs());
142 op, operandType,
ValueRange{exp, one}, subAttrs.getAttrs());
146 auto vectorType = resultType.dyn_cast<VectorType>();
151 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
154 mlir::VectorType::get(
159 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
160 auto exp = rewriter.
create<LLVM::ExpOp>(
161 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
162 return rewriter.
create<LLVM::FSubOp>(
163 loc, llvm1DVectorTy,
ValueRange{exp, one}, subAttrs.getAttrs());
174 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
176 auto operandType = adaptor.getOperand().getType();
181 auto loc = op.getLoc();
182 auto resultType = op.getResult().getType();
185 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
186 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
188 if (!operandType.isa<LLVM::LLVMArrayType>()) {
189 LLVM::ConstantOp one =
191 ? rewriter.
create<LLVM::ConstantOp>(
195 : rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
197 auto add = rewriter.
create<LLVM::FAddOp>(
198 loc, operandType,
ValueRange{one, adaptor.getOperand()},
199 addAttrs.getAttrs());
201 logAttrs.getAttrs());
205 auto vectorType = resultType.dyn_cast<VectorType>();
210 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
213 mlir::VectorType::get(
218 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
219 auto add = rewriter.
create<LLVM::FAddOp>(loc, llvm1DVectorTy,
221 addAttrs.getAttrs());
222 return rewriter.
create<LLVM::LogOp>(
223 loc, llvm1DVectorTy,
ValueRange{add}, logAttrs.getAttrs());
234 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
236 auto operandType = adaptor.getOperand().getType();
241 auto loc = op.getLoc();
242 auto resultType = op.getResult().getType();
245 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
246 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
248 if (!operandType.isa<LLVM::LLVMArrayType>()) {
249 LLVM::ConstantOp one;
251 one = rewriter.
create<LLVM::ConstantOp>(
255 one = rewriter.
create<LLVM::ConstantOp>(loc, operandType, floatOne);
257 auto sqrt = rewriter.
create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
258 sqrtAttrs.getAttrs());
260 op, operandType,
ValueRange{one, sqrt}, divAttrs.getAttrs());
264 auto vectorType = resultType.dyn_cast<VectorType>();
269 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
272 mlir::VectorType::get(
277 rewriter.
create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
278 auto sqrt = rewriter.
create<LLVM::SqrtOp>(
279 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
280 return rewriter.
create<LLVM::FDivOp>(
281 loc, llvm1DVectorTy,
ValueRange{one, sqrt}, divAttrs.getAttrs());
287 struct ConvertMathToLLVMPass
288 :
public impl::ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
289 ConvertMathToLLVMPass() =
default;
291 void runOnOperation()
override {
297 std::move(patterns))))
312 CountLeadingZerosOpLowering,
313 CountTrailingZerosOpLowering,
337 return std::make_unique<ConvertMathToLLVMPass>();
FloatAttr getFloatAttr(Type type, double value)
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createConvertMathToLLVMPass()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.