MLIR  16.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Pass/Pass.h"
18 
19 namespace mlir {
// Expand the generated pass definitions for this pass. The generated
// `Passes.h.inc` is expected to supply `impl::ConvertMathToLLVMBase`, which
// ConvertMathToLLVMPass derives from below.
20 #define GEN_PASS_DEF_CONVERTMATHTOLLVM
21 #include "mlir/Conversion/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
29 using CopySignOpLowering =
32 using CtPopFOpLowering =
36 using FloorOpLowering =
39 using Log10OpLowering =
44 using RoundEvenOpLowering =
46 using RoundOpLowering =
50 using FTruncOpLowering =
52 
53 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
//
// Generic lowering for integer Math ops whose LLVM counterpart takes one
// extra boolean flag operand. The flag is always materialized as the constant
// `false` (rewriter.getBoolAttr(false)).
// NOTE(review): for llvm.ctlz / llvm.cttz / llvm.abs a `false` flag is
// understood to mean "a zero (resp. INT_MIN) input does not yield poison" —
// confirm against the LLVM LangRef.
54 template <typename MathOp, typename LLVMOp>
55 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
// Convenience alias naming this pattern instantiation.
57  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
58 
// Rewrites `op` to LLVMOp(operand, false), unrolling LLVM-array operands
// (multi-dimensional vectors) into one LLVMOp per innermost 1-D vector.
60  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
61  ConversionPatternRewriter &rewriter) const override {
62  auto operandType = adaptor.getOperand().getType();
63 
    // Bail out (without a diagnostic) if the converted operand type is null
    // or not compatible with the LLVM dialect.
64  if (!operandType || !LLVM::isCompatibleType(operandType))
65  return failure();
66 
67  auto loc = op.getLoc();
68  auto resultType = op.getResult().getType();
69  auto boolZero = rewriter.getBoolAttr(false);
70 
    // Any non-array operand (e.g. a scalar or a 1-D LLVM vector) lowers to a
    // single LLVMOp that takes the operand plus the `false` flag constant.
71  if (!operandType.template isa<LLVM::LLVMArrayType>()) {
72  LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
73  rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
74  zero);
75  return success();
76  }
77 
    // An LLVM array operand is only handled when the original result is a
    // VectorType; otherwise the pattern does not match.
78  auto vectorType = resultType.template dyn_cast<VectorType>();
79  if (!vectorType)
80  return failure();
81 
    // Each innermost 1-D vector slice is lowered by the callback below.
83  op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
84  [&](Type llvm1DVectorTy, ValueRange operands) {
          // A fresh `false` flag constant is created for every slice.
85  LLVM::ConstantOp zero =
86  rewriter.create<LLVM::ConstantOp>(loc, boolZero);
87  return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
88  zero);
89  },
90  rewriter);
91  }
92 };
93 
// Concrete instantiations of IntOpWithFlagLowering for the three flagged
// integer ops (count-leading-zeros, count-trailing-zeros, integer abs).
94 using CountLeadingZerosOpLowering =
95  IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
96 using CountTrailingZerosOpLowering =
97  IntOpWithFlagLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
98 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
99 
100 // A `expm1` is converted into `exp - 1`.
//
// NOTE(review): exp(x) - 1 is computed literally, so for |x| close to 0 the
// subtraction cancels most significant bits and is less accurate than a true
// expm1 — confirm this approximation is acceptable for consumers.
101 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
103 
// Rewrites `op` to FSub(Exp(x), 1.0), unrolling LLVM-array operands
// (multi-dimensional vectors) into one Exp/FSub pair per 1-D vector.
105  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
106  ConversionPatternRewriter &rewriter) const override {
107  auto operandType = adaptor.getOperand().getType();
108 
    // NOTE(review): this early exit uses bare failure(), while the vector
    // check further down uses notifyMatchFailure with a reason — consider
    // making both report a reason.
109  if (!operandType || !LLVM::isCompatibleType(operandType))
110  return failure();
111 
112  auto loc = op.getLoc();
113  auto resultType = op.getResult().getType();
    // cast<> (not dyn_cast<>): the element type is assumed to be a float.
114  auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
115  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
116 
    // Non-array operand: build 1.0 as a splat for vector types, or as a
    // scalar constant otherwise, then emit Exp followed by FSub.
117  if (!operandType.isa<LLVM::LLVMArrayType>()) {
118  LLVM::ConstantOp one;
119  if (LLVM::isCompatibleVectorType(operandType)) {
120  one = rewriter.create<LLVM::ConstantOp>(
121  loc, operandType,
122  SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
123  } else {
124  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
125  }
126  auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
127  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
128  return success();
129  }
130 
    // LLVM array operand: only supported when the result is a VectorType.
131  auto vectorType = resultType.dyn_cast<VectorType>();
132  if (!vectorType)
133  return rewriter.notifyMatchFailure(op, "expected vector result type");
134 
    // Each innermost 1-D vector slice is lowered by the callback below.
136  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
137  [&](Type llvm1DVectorTy, ValueRange operands) {
          // Splat 1.0 sized to the slice's element count.
          // NOTE(review): getFixedValue() presumes a fixed-length
          // (non-scalable) 1-D vector — confirm scalable vectors cannot
          // reach this path.
138  auto splatAttr = SplatElementsAttr::get(
139  mlir::VectorType::get(
140  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
141  floatType),
142  floatOne);
143  auto one =
144  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
145  auto exp =
146  rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
147  return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
148  },
149  rewriter);
150  }
151 };
152 
153 // A `log1p` is converted into `log(1 + ...)`.
//
// NOTE(review): 1 + x is formed before the log, so for |x| much smaller than
// 1 the low-order bits of x are rounded away; this is less accurate than a
// true log1p — confirm this approximation is acceptable for consumers.
154 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
156 
// Rewrites `op` to Log(FAdd(1.0, x)), unrolling LLVM-array operands
// (multi-dimensional vectors) into one FAdd/Log pair per 1-D vector.
158  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
159  ConversionPatternRewriter &rewriter) const override {
160  auto operandType = adaptor.getOperand().getType();
161 
    // Reject null or non-LLVM-compatible converted operand types, with a
    // reason recorded via notifyMatchFailure.
162  if (!operandType || !LLVM::isCompatibleType(operandType))
163  return rewriter.notifyMatchFailure(op, "unsupported operand type");
164 
165  auto loc = op.getLoc();
166  auto resultType = op.getResult().getType();
    // cast<> (not dyn_cast<>): the element type is assumed to be a float.
167  auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
168  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
169 
    // Non-array operand: 1.0 is a splat for vector types and a scalar
    // constant otherwise; then emit FAdd followed by Log.
170  if (!operandType.isa<LLVM::LLVMArrayType>()) {
171  LLVM::ConstantOp one =
172  LLVM::isCompatibleVectorType(operandType)
173  ? rewriter.create<LLVM::ConstantOp>(
174  loc, operandType,
175  SplatElementsAttr::get(resultType.cast<ShapedType>(),
176  floatOne))
177  : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
178 
179  auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
180  adaptor.getOperand());
181  rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
182  return success();
183  }
184 
    // LLVM array operand: only supported when the result is a VectorType.
185  auto vectorType = resultType.dyn_cast<VectorType>();
186  if (!vectorType)
187  return rewriter.notifyMatchFailure(op, "expected vector result type");
188 
    // Each innermost 1-D vector slice is lowered by the callback below.
190  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
191  [&](Type llvm1DVectorTy, ValueRange operands) {
          // Splat 1.0 sized to the slice's element count.
          // NOTE(review): getFixedValue() presumes a fixed-length
          // (non-scalable) 1-D vector — confirm.
192  auto splatAttr = SplatElementsAttr::get(
193  mlir::VectorType::get(
194  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
195  floatType),
196  floatOne);
197  auto one =
198  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
199  auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
200  operands[0]);
201  return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
202  },
203  rewriter);
204  }
205 };
206 
207 // A `rsqrt` is converted into `1 / sqrt`.
208 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
210 
// Rewrites `op` to FDiv(1.0, Sqrt(x)), unrolling LLVM-array operands
// (multi-dimensional vectors) into one Sqrt/FDiv pair per 1-D vector.
212  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
213  ConversionPatternRewriter &rewriter) const override {
214  auto operandType = adaptor.getOperand().getType();
215 
    // NOTE(review): both match failures in this pattern return bare
    // failure() with no reason, unlike Log1pOpLowering, which uses
    // notifyMatchFailure — consider adding diagnostics for consistency.
216  if (!operandType || !LLVM::isCompatibleType(operandType))
217  return failure();
218 
219  auto loc = op.getLoc();
220  auto resultType = op.getResult().getType();
    // cast<> (not dyn_cast<>): the element type is assumed to be a float.
221  auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
222  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
223 
    // Non-array operand: 1.0 is a splat for vector types and a scalar
    // constant otherwise; then emit Sqrt followed by FDiv.
224  if (!operandType.isa<LLVM::LLVMArrayType>()) {
225  LLVM::ConstantOp one;
226  if (LLVM::isCompatibleVectorType(operandType)) {
227  one = rewriter.create<LLVM::ConstantOp>(
228  loc, operandType,
229  SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
230  } else {
231  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
232  }
233  auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
234  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
235  return success();
236  }
237 
    // LLVM array operand: only supported when the result is a VectorType.
238  auto vectorType = resultType.dyn_cast<VectorType>();
239  if (!vectorType)
240  return failure();
241 
    // Each innermost 1-D vector slice is lowered by the callback below.
243  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
244  [&](Type llvm1DVectorTy, ValueRange operands) {
          // Splat 1.0 sized to the slice's element count.
          // NOTE(review): getFixedValue() presumes a fixed-length
          // (non-scalable) 1-D vector — confirm.
245  auto splatAttr = SplatElementsAttr::get(
246  mlir::VectorType::get(
247  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
248  floatType),
249  floatOne);
250  auto one =
251  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
252  auto sqrt =
253  rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
254  return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
255  },
256  rewriter);
257  }
258 };
259 
// Pass that lowers Math dialect ops to the LLVM dialect using the patterns
// registered by populateMathToLLVMConversionPatterns.
260 struct ConvertMathToLLVMPass
261  : public impl::ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
262  ConvertMathToLLVMPass() = default;
263 
  // Builds a default-constructed LLVMTypeConverter and the Math->LLVM
  // patterns, then runs applyPartialConversion on the root operation against
  // LLVMConversionTarget. A conversion failure is reported through
  // signalPassFailure().
264  void runOnOperation() override {
265  RewritePatternSet patterns(&getContext());
266  LLVMTypeConverter converter(&getContext());
267  populateMathToLLVMConversionPatterns(converter, patterns);
268  LLVMConversionTarget target(getContext());
269  if (failed(applyPartialConversion(getOperation(), target,
270  std::move(patterns))))
271  signalPassFailure();
272  }
273 };
274 } // namespace
275 
277  RewritePatternSet &patterns) {
  // Register every Math->LLVM lowering pattern in `patterns`; each pattern
  // is constructed with the shared LLVM type converter `converter`.
278  // clang-format off
279  patterns.add<
280  AbsFOpLowering,
281  AbsIOpLowering,
282  CeilOpLowering,
283  CopySignOpLowering,
284  CosOpLowering,
285  CountLeadingZerosOpLowering,
286  CountTrailingZerosOpLowering,
287  CtPopFOpLowering,
288  Exp2OpLowering,
289  ExpM1OpLowering,
290  ExpOpLowering,
291  FloorOpLowering,
292  FmaOpLowering,
293  Log10OpLowering,
294  Log1pOpLowering,
295  Log2OpLowering,
296  LogOpLowering,
297  PowFOpLowering,
298  RoundEvenOpLowering,
299  RoundOpLowering,
300  RsqrtOpLowering,
301  SinOpLowering,
302  SqrtOpLowering,
303  FTruncOpLowering
304  >(converter);
305  // clang-format on
306 }
307 
// Creates a new, default-constructed instance of the Math-to-LLVM conversion
// pass, owned by the returned unique_ptr.
308 std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
309  return std::make_unique<ConvertMathToLLVMPass>();
310 }
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:133
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:894
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:67
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
std::unique_ptr< Pass > createConvertMathToLLVMPass()
Definition: MathToLLVM.cpp:308
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Derived class that automatically populates legalization information for different LLVM ops...
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:231
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:418
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:851
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
LLVM dialect array type.
Definition: LLVMTypes.h:75
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect. ...
Definition: LLVMTypes.cpp:869
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:97
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Definition: MathToLLVM.cpp:276
This class implements a pattern rewriter for use with ConversionPatterns.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
U cast() const
Definition: Types.h:279