MLIR  21.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 
21 #include "llvm/ADT/FloatingPointMode.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
// Shorthand for arith::AttrConvertFastMathToLLVM instantiated on a
// (source op, target op) pair. The hand-written lowerings in this file
// construct it from the source op and forward its getAttrs() result to the
// LLVM ops they create.
32 template <typename SourceOp, typename TargetOp>
33 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
34 
// Pattern alias used by the op-to-op lowering aliases that follow; each of
// them instantiates it with a math-dialect op and an LLVM-dialect op.
// NOTE(review): listing line 37, which held the right-hand side of this alias,
// is missing from this copy -- restore it from the original file.
35 template <typename SourceOp, typename TargetOp>
36 using ConvertFMFMathToLLVMPattern =
38 
// One alias per math op that goes through ConvertFMFMathToLLVMPattern. Each
// alias pairs a math-dialect op with an LLVM-dialect op; all of them are
// registered together in the patterns.add<...> list further down the file.
39 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41 using CopySignOpLowering =
42  ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45 using CtPopFOpLowering =
// NOTE(review): listing line 46 (the right-hand side of CtPopFOpLowering) is
// missing from this copy -- restore it from the original file.
47 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
48 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
49 using FloorOpLowering =
50  ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
51 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
52 using Log10OpLowering =
53  ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
54 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
55 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
56 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
57 using FPowIOpLowering =
58  ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
59 using RoundEvenOpLowering =
60  ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
61 using RoundOpLowering =
62  ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
63 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
64 using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
65 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
66 using FTruncOpLowering =
67  ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
68 using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
69 using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
70 
71 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
//
// Generic pattern: MathOp is the single-operand math op being matched and
// LLVMOp is the op created in its place, always with the literal `false` as
// trailing argument. Operands that are not LLVM array types are replaced by a
// single LLVMOp; array-typed operands are only handled when the source result
// is a VectorType, in which case the lambda below builds one LLVMOp per 1-D
// vector type it is handed.
72 template <typename MathOp, typename LLVMOp>
73 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
  // NOTE(review): listing line 74 is missing here; presumably the
  // using-declaration that inherits the base-class constructors (the pattern
  // is registered with (converter, benefit) below) -- confirm against the
  // original file.
75  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
76 
77  LogicalResult
78  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
79  ConversionPatternRewriter &rewriter) const override {
  // Operand type as reported by the adaptor.
80  auto operandType = adaptor.getOperand().getType();
81 
  // Give up unless the operand has an LLVM-compatible type.
82  if (!operandType || !LLVM::isCompatibleType(operandType))
83  return failure();
84 
85  auto loc = op.getLoc();
86  auto resultType = op.getResult().getType();
87 
  // Operand is not an LLVM array type: replace the op directly, passing
  // `false` as the extra argument.
88  if (!isa<LLVM::LLVMArrayType>(operandType)) {
89  rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
90  false);
91  return success();
92  }
93 
  // Array-typed operand: only supported when the source result is a
  // VectorType.
94  auto vectorType = dyn_cast<VectorType>(resultType);
95  if (!vectorType)
96  return failure();
97 
  // NOTE(review): listing line 98 is missing here. The arguments that follow
  // (operation, operands, type converter, per-1-D-vector builder, rewriter)
  // line up with the handleMultidimensionalVectors helper, so the `return` of
  // that call was presumably on the dropped line -- confirm.
99  op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
100  [&](Type llvm1DVectorTy, ValueRange operands) {
101  return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
102  false);
103  },
104  rewriter);
105  }
106 };
107 
// Instantiations of IntOpWithFlagLowering for the count-leading-zeros,
// count-trailing-zeros and integer-abs math ops and their LLVM counterparts.
108 using CountLeadingZerosOpLowering =
109  IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
110 using CountTrailingZerosOpLowering =
111  IntOpWithFlagLowering<math::CountTrailingZerosOp,
112  LLVM::CountTrailingZerosOp>;
113 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
114 
115 // A `expm1` is converted into `exp - 1`.
//
// Emits an LLVM::ConstantOp holding 1.0, an LLVM::ExpOp on the operand and an
// LLVM::FSubOp computing `exp - one`. Attributes for the exp and fsub ops are
// derived from the source op through ConvertFastMath.
116 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
  // NOTE(review): listing line 117 is missing here; presumably the
  // using-declaration that inherits the base-class constructors -- confirm.
118 
119  LogicalResult
120  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
121  ConversionPatternRewriter &rewriter) const override {
  // Operand type as reported by the adaptor.
122  auto operandType = adaptor.getOperand().getType();
123 
  // Give up unless the operand has an LLVM-compatible type.
124  if (!operandType || !LLVM::isCompatibleType(operandType))
125  return failure();
126 
127  auto loc = op.getLoc();
128  auto resultType = op.getResult().getType();
  // Element type of the result (or the result type itself when it has no
  // element type); cast<> asserts that it is a FloatType.
129  auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
  // Float attribute 1.0 of that element type.
130  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
  // Attribute converters for the two ops created below.
131  ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
132  ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
133 
  // Operand is not an LLVM array type: build the constant, exp and fsub
  // directly on the operand type.
134  if (!isa<LLVM::LLVMArrayType>(operandType)) {
135  LLVM::ConstantOp one;
  // Vector operand: splat 1.0 over the result's shaped type; otherwise a
  // plain 1.0 constant.
136  if (LLVM::isCompatibleVectorType(operandType)) {
137  one = rewriter.create<LLVM::ConstantOp>(
138  loc, operandType,
139  SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
140  } else {
141  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
142  }
143  auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
144  expAttrs.getAttrs());
  // Replace the source op with `exp - one`.
145  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
146  op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
147  return success();
148  }
149 
  // Array-typed operand: only supported when the source result is a
  // VectorType.
150  auto vectorType = dyn_cast<VectorType>(resultType);
151  if (!vectorType)
152  return rewriter.notifyMatchFailure(op, "expected vector result type");
153 
  // NOTE(review): listing line 154 is missing here. The arguments that follow
  // line up with the handleMultidimensionalVectors helper, so the `return` of
  // that call was presumably on the dropped line -- confirm.
155  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
156  [&](Type llvm1DVectorTy, ValueRange operands) {
  // Build a 1-D splat of 1.0 whose length is the known-minimum element
  // count of llvm1DVectorTy, marked scalable when that count is scalable.
157  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
158  auto splatAttr = SplatElementsAttr::get(
159  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
160  {numElements.isScalable()}),
161  floatOne);
162  auto one =
163  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
164  auto exp = rewriter.create<LLVM::ExpOp>(
165  loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
  // Per-1-D-vector result: `exp - one`.
166  return rewriter.create<LLVM::FSubOp>(
167  loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
168  },
169  rewriter);
170  }
171 };
172 
173 // A `log1p` is converted into `log(1 + ...)`.
//
// Emits an LLVM::ConstantOp holding 1.0, an LLVM::FAddOp computing
// `one + operand` and an LLVM::LogOp on that sum. Attributes for the fadd and
// log ops are derived from the source op through ConvertFastMath. This pattern
// is only registered when `approximateLog1p` is set (see the populate function
// later in this file).
174 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
  // NOTE(review): listing line 175 is missing here; presumably the
  // using-declaration that inherits the base-class constructors -- confirm.
176 
177  LogicalResult
178  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
179  ConversionPatternRewriter &rewriter) const override {
  // Operand type as reported by the adaptor.
180  auto operandType = adaptor.getOperand().getType();
181 
  // Report a match failure unless the operand has an LLVM-compatible type.
182  if (!operandType || !LLVM::isCompatibleType(operandType))
183  return rewriter.notifyMatchFailure(op, "unsupported operand type");
184 
185  auto loc = op.getLoc();
186  auto resultType = op.getResult().getType();
  // Element type of the result (or the result type itself when it has no
  // element type); cast<> asserts that it is a FloatType.
187  auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
  // Float attribute 1.0 of that element type.
188  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
  // Attribute converters for the two ops created below.
189  ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
190  ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
191 
  // Operand is not an LLVM array type: build the constant, fadd and log
  // directly on the operand type.
192  if (!isa<LLVM::LLVMArrayType>(operandType)) {
  // Vector operand: splat 1.0 over the result's shaped type; otherwise a
  // plain 1.0 constant.
193  LLVM::ConstantOp one =
194  LLVM::isCompatibleVectorType(operandType)
195  ? rewriter.create<LLVM::ConstantOp>(
196  loc, operandType,
197  SplatElementsAttr::get(cast<ShapedType>(resultType),
198  floatOne))
199  : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
200 
201  auto add = rewriter.create<LLVM::FAddOp>(
202  loc, operandType, ValueRange{one, adaptor.getOperand()},
203  addAttrs.getAttrs());
  // Replace the source op with `log(one + operand)`.
204  rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
205  logAttrs.getAttrs());
206  return success();
207  }
208 
  // Array-typed operand: only supported when the source result is a
  // VectorType.
209  auto vectorType = dyn_cast<VectorType>(resultType);
210  if (!vectorType)
211  return rewriter.notifyMatchFailure(op, "expected vector result type");
212 
  // NOTE(review): listing line 213 is missing here. The arguments that follow
  // line up with the handleMultidimensionalVectors helper, so the `return` of
  // that call was presumably on the dropped line -- confirm.
214  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
215  [&](Type llvm1DVectorTy, ValueRange operands) {
  // Build a 1-D splat of 1.0 whose length is the known-minimum element
  // count of llvm1DVectorTy, marked scalable when that count is scalable.
216  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
217  auto splatAttr = SplatElementsAttr::get(
218  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
219  {numElements.isScalable()}),
220  floatOne);
221  auto one =
222  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
223  auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
224  ValueRange{one, operands[0]},
225  addAttrs.getAttrs());
  // Per-1-D-vector result: `log(one + operand)`.
226  return rewriter.create<LLVM::LogOp>(
227  loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
228  },
229  rewriter);
230  }
231 };
232 
233 // A `rsqrt` is converted into `1 / sqrt`.
//
// Emits an LLVM::ConstantOp holding 1.0, an LLVM::SqrtOp on the operand and an
// LLVM::FDivOp computing `one / sqrt`. Attributes for the sqrt and fdiv ops
// are derived from the source op through ConvertFastMath.
234 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
  // NOTE(review): listing line 235 is missing here; presumably the
  // using-declaration that inherits the base-class constructors -- confirm.
236 
237  LogicalResult
238  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
239  ConversionPatternRewriter &rewriter) const override {
  // Operand type as reported by the adaptor.
240  auto operandType = adaptor.getOperand().getType();
241 
  // Give up unless the operand has an LLVM-compatible type.
242  if (!operandType || !LLVM::isCompatibleType(operandType))
243  return failure();
244 
245  auto loc = op.getLoc();
246  auto resultType = op.getResult().getType();
  // Element type of the result (or the result type itself when it has no
  // element type); cast<> asserts that it is a FloatType.
247  auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
  // Float attribute 1.0 of that element type.
248  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
  // Attribute converters for the two ops created below.
249  ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
250  ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
251 
  // Operand is not an LLVM array type: build the constant, sqrt and fdiv
  // directly on the operand type.
252  if (!isa<LLVM::LLVMArrayType>(operandType)) {
253  LLVM::ConstantOp one;
  // Vector operand: splat 1.0 over the result's shaped type; otherwise a
  // plain 1.0 constant.
254  if (LLVM::isCompatibleVectorType(operandType)) {
255  one = rewriter.create<LLVM::ConstantOp>(
256  loc, operandType,
257  SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
258  } else {
259  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
260  }
261  auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
262  sqrtAttrs.getAttrs());
  // Replace the source op with `one / sqrt`.
263  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
264  op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
265  return success();
266  }
267 
  // Array-typed operand: only supported when the source result is a
  // VectorType.
268  auto vectorType = dyn_cast<VectorType>(resultType);
269  if (!vectorType)
270  return failure();
271 
  // NOTE(review): listing line 272 is missing here. The arguments that follow
  // line up with the handleMultidimensionalVectors helper, so the `return` of
  // that call was presumably on the dropped line -- confirm.
273  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
274  [&](Type llvm1DVectorTy, ValueRange operands) {
  // Build a 1-D splat of 1.0 whose length is the known-minimum element
  // count of llvm1DVectorTy, marked scalable when that count is scalable.
275  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
276  auto splatAttr = SplatElementsAttr::get(
277  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
278  {numElements.isScalable()}),
279  floatOne);
280  auto one =
281  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
282  auto sqrt = rewriter.create<LLVM::SqrtOp>(
283  loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
  // Per-1-D-vector result: `one / sqrt`.
284  return rewriter.create<LLVM::FDivOp>(
285  loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
286  },
287  rewriter);
288  }
289 };
290 
// Rewrites math::IsNaNOp into LLVM::IsFPClass with the llvm::fcNan class mask.
// The result type is taken from the source op (op.getType()).
291 struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
  // NOTE(review): listing line 292 is missing here; presumably the
  // using-declaration that inherits the base-class constructors -- confirm.
293 
294  LogicalResult
295  matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
296  ConversionPatternRewriter &rewriter) const override {
  // Operand type as reported by the adaptor.
297  auto operandType = adaptor.getOperand().getType();
298 
  // Give up unless the operand has an LLVM-compatible type.
299  if (!operandType || !LLVM::isCompatibleType(operandType))
300  return failure();
301 
302  rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
303  op, op.getType(), adaptor.getOperand(), llvm::fcNan);
304  return success();
305  }
306 };
307 
// Rewrites math::IsFiniteOp into LLVM::IsFPClass with the llvm::fcFinite class
// mask. The result type is taken from the source op (op.getType()).
308 struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
  // NOTE(review): listing line 309 is missing here; presumably the
  // using-declaration that inherits the base-class constructors -- confirm.
310 
311  LogicalResult
312  matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
313  ConversionPatternRewriter &rewriter) const override {
  // Operand type as reported by the adaptor.
314  auto operandType = adaptor.getOperand().getType();
315 
  // Give up unless the operand has an LLVM-compatible type.
316  if (!operandType || !LLVM::isCompatibleType(operandType))
317  return failure();
318 
319  rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
320  op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
321  return success();
322  }
323 };
324 
// Pass that populates the Math-to-LLVM patterns and applies them to the
// current operation as a partial conversion, signalling pass failure when the
// conversion fails.
325 struct ConvertMathToLLVMPass
326  : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
327  using Base::Base;
328 
329  void runOnOperation() override {
  // NOTE(review): listing line 330 is missing here; `patterns` is used below
  // without a visible declaration, so it was presumably declared on that
  // line -- confirm against the original file.
331  LLVMTypeConverter converter(&getContext());
  // `approximateLog1p` is not declared in this struct; presumably a pass
  // option provided by the generated base class -- confirm.
332  populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
  // NOTE(review): listing line 333 is missing here; `target` is used below
  // without a visible declaration, so it was presumably declared on that
  // line -- confirm against the original file.
334  if (failed(applyPartialConversion(getOperation(), target,
335  std::move(patterns))))
336  signalPassFailure();
337  }
338 };
339 } // namespace
340 
// NOTE(review): listing line 341, the first line of this function definition
// (return type and name), is missing from this copy. The documentation residue
// at the end of the listing names the function defined at MathToLLVM.cpp:341
// as populateMathToLLVMConversionPatterns -- confirm and restore.
//
// Adds the Math-to-LLVM lowering patterns defined in this file to `patterns`,
// each constructed with `converter` and `benefit`. Log1pOpLowering is added
// only when `approximateLog1p` is true; every other pattern is added
// unconditionally.
342  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
343  bool approximateLog1p, PatternBenefit benefit) {
344  if (approximateLog1p)
345  patterns.add<Log1pOpLowering>(converter, benefit);
346  // clang-format off
347  patterns.add<
348  IsNaNOpLowering,
349  IsFiniteOpLowering,
350  AbsFOpLowering,
351  AbsIOpLowering,
352  CeilOpLowering,
353  CopySignOpLowering,
354  CosOpLowering,
355  CoshOpLowering,
356  CountLeadingZerosOpLowering,
357  CountTrailingZerosOpLowering,
358  CtPopFOpLowering,
359  Exp2OpLowering,
360  ExpM1OpLowering,
361  ExpOpLowering,
362  FPowIOpLowering,
363  FloorOpLowering,
364  FmaOpLowering,
365  Log10OpLowering,
366  Log2OpLowering,
367  LogOpLowering,
368  PowFOpLowering,
369  RoundEvenOpLowering,
370  RoundOpLowering,
371  RsqrtOpLowering,
372  SinOpLowering,
373  SinhOpLowering,
374  SqrtOpLowering,
375  FTruncOpLowering,
376  TanOpLowering,
377  TanhOpLowering
378  >(converter, benefit);
379  // clang-format on
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // ConvertToLLVMPatternInterface implementation
384 //===----------------------------------------------------------------------===//
385 
386 namespace {
387 /// Implement the interface to convert Math to LLVM.
/// loadDependentDialects loads the LLVM dialect into the given context.
388 struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  // NOTE(review): listing line 389 is missing here; presumably the
  // using-declaration that inherits the base-class constructor -- confirm.
390  void loadDependentDialects(MLIRContext *context) const final {
391  context->loadDialect<LLVM::LLVMDialect>();
392  }
393 
394  /// Hook for derived dialect interface to provide conversion patterns
395  /// and mark dialect legal for the conversion target.
396  void populateConvertToLLVMConversionPatterns(
397  ConversionTarget &target, LLVMTypeConverter &typeConverter,
398  RewritePatternSet &patterns) const final {
  // NOTE(review): listing line 399, the only statement of this body, is
  // missing from this copy; presumably the call that fills `patterns` with
  // the Math-to-LLVM patterns -- confirm against the original file.
400  }
401 };
402 } // namespace
403 
// NOTE(review): listing line 404, the first line of this function definition,
// is missing from this copy. The documentation residue at the end of the
// listing names the function defined at MathToLLVM.cpp:404 as
// `void registerConvertMathToLLVMInterface(DialectRegistry &registry)` --
// confirm and restore.
//
// Registers an extension on `registry` that attaches
// MathToLLVMDialectInterface to the Math dialect.
405  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
406  dialect->addInterfaces<MathToLLVMDialectInterface>();
407  });
408 }
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:148
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:882
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:864
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:907
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Definition: MathToLLVM.cpp:341
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
Definition: MathToLLVM.cpp:404