MLIR  20.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 
30 template <typename SourceOp, typename TargetOp>
31 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
32 
33 template <typename SourceOp, typename TargetOp>
34 using ConvertFMFMathToLLVMPattern =
36 
37 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
38 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
39 using CopySignOpLowering =
40  ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
41 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
42 using CtPopFOpLowering =
44 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
45 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
46 using FloorOpLowering =
47  ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
48 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
49 using Log10OpLowering =
50  ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
51 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
52 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
53 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
54 using FPowIOpLowering =
55  ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
56 using RoundEvenOpLowering =
57  ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
58 using RoundOpLowering =
59  ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
60 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
61 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
62 using FTruncOpLowering =
63  ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
64 
65 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
66 template <typename MathOp, typename LLVMOp>
67 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
69  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
70 
71  LogicalResult
72  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
73  ConversionPatternRewriter &rewriter) const override {
74  auto operandType = adaptor.getOperand().getType();
75 
76  if (!operandType || !LLVM::isCompatibleType(operandType))
77  return failure();
78 
79  auto loc = op.getLoc();
80  auto resultType = op.getResult().getType();
81 
82  if (!isa<LLVM::LLVMArrayType>(operandType)) {
83  rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
84  false);
85  return success();
86  }
87 
88  auto vectorType = dyn_cast<VectorType>(resultType);
89  if (!vectorType)
90  return failure();
91 
93  op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
94  [&](Type llvm1DVectorTy, ValueRange operands) {
95  return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
96  false);
97  },
98  rewriter);
99  }
100 };
101 
102 using CountLeadingZerosOpLowering =
103  IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
104 using CountTrailingZerosOpLowering =
105  IntOpWithFlagLowering<math::CountTrailingZerosOp,
106  LLVM::CountTrailingZerosOp>;
107 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
108 
109 // A `expm1` is converted into `exp - 1`.
110 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
112 
113  LogicalResult
114  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
115  ConversionPatternRewriter &rewriter) const override {
116  auto operandType = adaptor.getOperand().getType();
117 
118  if (!operandType || !LLVM::isCompatibleType(operandType))
119  return failure();
120 
121  auto loc = op.getLoc();
122  auto resultType = op.getResult().getType();
123  auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
124  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
125  ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
126  ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
127 
128  if (!isa<LLVM::LLVMArrayType>(operandType)) {
129  LLVM::ConstantOp one;
130  if (LLVM::isCompatibleVectorType(operandType)) {
131  one = rewriter.create<LLVM::ConstantOp>(
132  loc, operandType,
133  SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
134  } else {
135  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
136  }
137  auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
138  expAttrs.getAttrs());
139  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
140  op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
141  return success();
142  }
143 
144  auto vectorType = dyn_cast<VectorType>(resultType);
145  if (!vectorType)
146  return rewriter.notifyMatchFailure(op, "expected vector result type");
147 
149  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
150  [&](Type llvm1DVectorTy, ValueRange operands) {
151  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
152  auto splatAttr = SplatElementsAttr::get(
153  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
154  {numElements.isScalable()}),
155  floatOne);
156  auto one =
157  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
158  auto exp = rewriter.create<LLVM::ExpOp>(
159  loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
160  return rewriter.create<LLVM::FSubOp>(
161  loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
162  },
163  rewriter);
164  }
165 };
166 
167 // A `log1p` is converted into `log(1 + ...)`.
168 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
170 
171  LogicalResult
172  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
173  ConversionPatternRewriter &rewriter) const override {
174  auto operandType = adaptor.getOperand().getType();
175 
176  if (!operandType || !LLVM::isCompatibleType(operandType))
177  return rewriter.notifyMatchFailure(op, "unsupported operand type");
178 
179  auto loc = op.getLoc();
180  auto resultType = op.getResult().getType();
181  auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
182  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
183  ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
184  ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
185 
186  if (!isa<LLVM::LLVMArrayType>(operandType)) {
187  LLVM::ConstantOp one =
188  LLVM::isCompatibleVectorType(operandType)
189  ? rewriter.create<LLVM::ConstantOp>(
190  loc, operandType,
191  SplatElementsAttr::get(cast<ShapedType>(resultType),
192  floatOne))
193  : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
194 
195  auto add = rewriter.create<LLVM::FAddOp>(
196  loc, operandType, ValueRange{one, adaptor.getOperand()},
197  addAttrs.getAttrs());
198  rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
199  logAttrs.getAttrs());
200  return success();
201  }
202 
203  auto vectorType = dyn_cast<VectorType>(resultType);
204  if (!vectorType)
205  return rewriter.notifyMatchFailure(op, "expected vector result type");
206 
208  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
209  [&](Type llvm1DVectorTy, ValueRange operands) {
210  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
211  auto splatAttr = SplatElementsAttr::get(
212  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
213  {numElements.isScalable()}),
214  floatOne);
215  auto one =
216  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
217  auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
218  ValueRange{one, operands[0]},
219  addAttrs.getAttrs());
220  return rewriter.create<LLVM::LogOp>(
221  loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
222  },
223  rewriter);
224  }
225 };
226 
227 // A `rsqrt` is converted into `1 / sqrt`.
228 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
230 
231  LogicalResult
232  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
233  ConversionPatternRewriter &rewriter) const override {
234  auto operandType = adaptor.getOperand().getType();
235 
236  if (!operandType || !LLVM::isCompatibleType(operandType))
237  return failure();
238 
239  auto loc = op.getLoc();
240  auto resultType = op.getResult().getType();
241  auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
242  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
243  ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
244  ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
245 
246  if (!isa<LLVM::LLVMArrayType>(operandType)) {
247  LLVM::ConstantOp one;
248  if (LLVM::isCompatibleVectorType(operandType)) {
249  one = rewriter.create<LLVM::ConstantOp>(
250  loc, operandType,
251  SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
252  } else {
253  one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
254  }
255  auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
256  sqrtAttrs.getAttrs());
257  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
258  op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
259  return success();
260  }
261 
262  auto vectorType = dyn_cast<VectorType>(resultType);
263  if (!vectorType)
264  return failure();
265 
267  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
268  [&](Type llvm1DVectorTy, ValueRange operands) {
269  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
270  auto splatAttr = SplatElementsAttr::get(
271  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
272  {numElements.isScalable()}),
273  floatOne);
274  auto one =
275  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
276  auto sqrt = rewriter.create<LLVM::SqrtOp>(
277  loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
278  return rewriter.create<LLVM::FDivOp>(
279  loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
280  },
281  rewriter);
282  }
283 };
284 
285 struct ConvertMathToLLVMPass
286  : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
287  using Base::Base;
288 
289  void runOnOperation() override {
290  RewritePatternSet patterns(&getContext());
291  LLVMTypeConverter converter(&getContext());
292  populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
294  if (failed(applyPartialConversion(getOperation(), target,
295  std::move(patterns))))
296  signalPassFailure();
297  }
298 };
299 } // namespace
300 
302  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
303  bool approximateLog1p) {
304  if (approximateLog1p)
305  patterns.add<Log1pOpLowering>(converter);
306  // clang-format off
307  patterns.add<
308  AbsFOpLowering,
309  AbsIOpLowering,
310  CeilOpLowering,
311  CopySignOpLowering,
312  CosOpLowering,
313  CountLeadingZerosOpLowering,
314  CountTrailingZerosOpLowering,
315  CtPopFOpLowering,
316  Exp2OpLowering,
317  ExpM1OpLowering,
318  ExpOpLowering,
319  FPowIOpLowering,
320  FloorOpLowering,
321  FmaOpLowering,
322  Log10OpLowering,
323  Log2OpLowering,
324  LogOpLowering,
325  PowFOpLowering,
326  RoundEvenOpLowering,
327  RoundOpLowering,
328  RsqrtOpLowering,
329  SinOpLowering,
330  SqrtOpLowering,
331  FTruncOpLowering
332  >(converter);
333  // clang-format on
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // ConvertToLLVMPatternInterface implementation
338 //===----------------------------------------------------------------------===//
339 
340 namespace {
341 /// Implement the interface to convert Math to LLVM.
342 struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
344  void loadDependentDialects(MLIRContext *context) const final {
345  context->loadDialect<LLVM::LLVMDialect>();
346  }
347 
348  /// Hook for derived dialect interface to provide conversion patterns
349  /// and mark dialect legal for the conversion target.
350  void populateConvertToLLVMConversionPatterns(
351  ConversionTarget &target, LLVMTypeConverter &typeConverter,
352  RewritePatternSet &patterns) const final {
353  populateMathToLLVMConversionPatterns(typeConverter, patterns);
354  }
355 };
356 } // namespace
357 
359  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
360  dialect->addInterfaces<MathToLLVMDialectInterface>();
361  });
362 }
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:876
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:901
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true)
Definition: MathToLLVM.cpp:301
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
Definition: MathToLLVM.cpp:358