MLIR  17.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CONVERTMATHTOLLVM
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 template <typename SourceOp, typename TargetOp>
30 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
31 
32 template <typename SourceOp, typename TargetOp>
33 using ConvertFMFMathToLLVMPattern =
35 
36 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
37 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
38 using CopySignOpLowering =
39  ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
40 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
41 using CtPopFOpLowering =
43 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
44 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
45 using FloorOpLowering =
46  ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
47 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
48 using Log10OpLowering =
49  ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
50 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
51 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
52 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
53 using FPowIOpLowering =
54  ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
55 using RoundEvenOpLowering =
56  ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
57 using RoundOpLowering =
58  ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
59 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
60 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
61 using FTruncOpLowering =
62  ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
63 
64 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
65 template <typename MathOp, typename LLVMOp>
66 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
68  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
69 
71  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
72  ConversionPatternRewriter &rewriter) const override {
73  auto operandType = adaptor.getOperand().getType();
74 
75  if (!operandType || !LLVM::isCompatibleType(operandType))
76  return failure();
77 
78  auto loc = op.getLoc();
79  auto resultType = op.getResult().getType();
80  auto boolZero = rewriter.getBoolAttr(false);
81 
82  if (!operandType.template isa<LLVM::LLVMArrayType>()) {
83  LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
84  rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
85  zero);
86  return success();
87  }
88 
89  auto vectorType = resultType.template dyn_cast<VectorType>();
90  if (!vectorType)
91  return failure();
92 
94  op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
95  [&](Type llvm1DVectorTy, ValueRange operands) {
96  LLVM::ConstantOp zero =
97  rewriter.create<LLVM::ConstantOp>(loc, boolZero);
98  return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
99  zero);
100  },
101  rewriter);
102  }
103 };
104 
105 using CountLeadingZerosOpLowering =
106  IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
107 using CountTrailingZerosOpLowering =
108  IntOpWithFlagLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
109 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
110 
111 // A `expm1` is converted into `exp - 1`.
112 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
114 
116  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
117  ConversionPatternRewriter &rewriter) const override {
118  auto operandType = adaptor.getOperand().getType();
119 
120  if (!operandType || !LLVM::isCompatibleType(operandType))
121  return failure();
122 
123  auto loc = op.getLoc();
124  auto resultType = op.getResult().getType();
125  auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
126  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
127  ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
128  ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
129 
130  if (!operandType.isa<LLVM::LLVMArrayType>()) {
131  LLVM::ConstantOp one;
132  if (LLVM::isCompatibleVectorType(operandType)) {
133  one = rewriter.create<LLVM::ConstantOp>(
134  loc, operandType,
135  SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
136  } else {
137  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
138  }
139  auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
140  expAttrs.getAttrs());
141  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
142  op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
143  return success();
144  }
145 
146  auto vectorType = resultType.dyn_cast<VectorType>();
147  if (!vectorType)
148  return rewriter.notifyMatchFailure(op, "expected vector result type");
149 
151  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
152  [&](Type llvm1DVectorTy, ValueRange operands) {
153  auto splatAttr = SplatElementsAttr::get(
154  mlir::VectorType::get(
155  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
156  floatType),
157  floatOne);
158  auto one =
159  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
160  auto exp = rewriter.create<LLVM::ExpOp>(
161  loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
162  return rewriter.create<LLVM::FSubOp>(
163  loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
164  },
165  rewriter);
166  }
167 };
168 
169 // A `log1p` is converted into `log(1 + ...)`.
170 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
172 
174  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
175  ConversionPatternRewriter &rewriter) const override {
176  auto operandType = adaptor.getOperand().getType();
177 
178  if (!operandType || !LLVM::isCompatibleType(operandType))
179  return rewriter.notifyMatchFailure(op, "unsupported operand type");
180 
181  auto loc = op.getLoc();
182  auto resultType = op.getResult().getType();
183  auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
184  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
185  ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
186  ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
187 
188  if (!operandType.isa<LLVM::LLVMArrayType>()) {
189  LLVM::ConstantOp one =
190  LLVM::isCompatibleVectorType(operandType)
191  ? rewriter.create<LLVM::ConstantOp>(
192  loc, operandType,
193  SplatElementsAttr::get(resultType.cast<ShapedType>(),
194  floatOne))
195  : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
196 
197  auto add = rewriter.create<LLVM::FAddOp>(
198  loc, operandType, ValueRange{one, adaptor.getOperand()},
199  addAttrs.getAttrs());
200  rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
201  logAttrs.getAttrs());
202  return success();
203  }
204 
205  auto vectorType = resultType.dyn_cast<VectorType>();
206  if (!vectorType)
207  return rewriter.notifyMatchFailure(op, "expected vector result type");
208 
210  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
211  [&](Type llvm1DVectorTy, ValueRange operands) {
212  auto splatAttr = SplatElementsAttr::get(
213  mlir::VectorType::get(
214  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
215  floatType),
216  floatOne);
217  auto one =
218  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
219  auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
220  ValueRange{one, operands[0]},
221  addAttrs.getAttrs());
222  return rewriter.create<LLVM::LogOp>(
223  loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
224  },
225  rewriter);
226  }
227 };
228 
229 // A `rsqrt` is converted into `1 / sqrt`.
230 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
232 
234  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
235  ConversionPatternRewriter &rewriter) const override {
236  auto operandType = adaptor.getOperand().getType();
237 
238  if (!operandType || !LLVM::isCompatibleType(operandType))
239  return failure();
240 
241  auto loc = op.getLoc();
242  auto resultType = op.getResult().getType();
243  auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
244  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
245  ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
246  ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
247 
248  if (!operandType.isa<LLVM::LLVMArrayType>()) {
249  LLVM::ConstantOp one;
250  if (LLVM::isCompatibleVectorType(operandType)) {
251  one = rewriter.create<LLVM::ConstantOp>(
252  loc, operandType,
253  SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
254  } else {
255  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
256  }
257  auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
258  sqrtAttrs.getAttrs());
259  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
260  op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
261  return success();
262  }
263 
264  auto vectorType = resultType.dyn_cast<VectorType>();
265  if (!vectorType)
266  return failure();
267 
269  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
270  [&](Type llvm1DVectorTy, ValueRange operands) {
271  auto splatAttr = SplatElementsAttr::get(
272  mlir::VectorType::get(
273  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
274  floatType),
275  floatOne);
276  auto one =
277  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
278  auto sqrt = rewriter.create<LLVM::SqrtOp>(
279  loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
280  return rewriter.create<LLVM::FDivOp>(
281  loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
282  },
283  rewriter);
284  }
285 };
286 
287 struct ConvertMathToLLVMPass
288  : public impl::ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
289  ConvertMathToLLVMPass() = default;
290 
291  void runOnOperation() override {
292  RewritePatternSet patterns(&getContext());
293  LLVMTypeConverter converter(&getContext());
294  populateMathToLLVMConversionPatterns(converter, patterns);
295  LLVMConversionTarget target(getContext());
296  if (failed(applyPartialConversion(getOperation(), target,
297  std::move(patterns))))
298  signalPassFailure();
299  }
300 };
301 } // namespace
302 
304  RewritePatternSet &patterns) {
305  // clang-format off
306  patterns.add<
307  AbsFOpLowering,
308  AbsIOpLowering,
309  CeilOpLowering,
310  CopySignOpLowering,
311  CosOpLowering,
312  CountLeadingZerosOpLowering,
313  CountTrailingZerosOpLowering,
314  CtPopFOpLowering,
315  Exp2OpLowering,
316  ExpM1OpLowering,
317  ExpOpLowering,
318  FPowIOpLowering,
319  FloorOpLowering,
320  FmaOpLowering,
321  Log10OpLowering,
322  Log1pOpLowering,
323  Log2OpLowering,
324  LogOpLowering,
325  PowFOpLowering,
326  RoundEvenOpLowering,
327  RoundOpLowering,
328  RsqrtOpLowering,
329  SinOpLowering,
330  SqrtOpLowering,
331  FTruncOpLowering
332  >(converter);
333  // clang-format on
334 }
335 
336 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
337  return std::make_unique<ConvertMathToLLVMPass>();
338 }
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:235
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:101
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:135
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
U cast() const
Definition: Types.h:280
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:350
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:87
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:855
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:837
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:880
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Definition: MathToLLVM.cpp:303
std::unique_ptr< Pass > createConvertMathToLLVMPass()
Definition: MathToLLVM.cpp:336
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26