MLIR  22.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/FloatingPointMode.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 template <typename SourceOp, typename TargetOp>
33 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
34 
35 template <typename SourceOp, typename TargetOp>
36 using ConvertFMFMathToLLVMPattern =
38 
39 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41 using CopySignOpLowering =
42  ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45 using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
46 using CtPopFOpLowering =
48 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50 using FloorOpLowering =
51  ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53 using Log10OpLowering =
54  ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
57 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58 using FPowIOpLowering =
59  ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60 using RoundEvenOpLowering =
61  ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62 using RoundOpLowering =
63  ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65 using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66 using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
68 using FTruncOpLowering =
69  ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70 using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71 using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72 using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73 using ATan2OpLowering =
74  ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
75 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
76 // TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
77 // may be better to separate the patterns.
78 template <typename MathOp, typename LLVMOp>
79 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
81  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
82 
83  LogicalResult
84  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
85  ConversionPatternRewriter &rewriter) const override {
86  const auto &typeConverter = *this->getTypeConverter();
87  auto operandType = adaptor.getOperand().getType();
88  auto llvmOperandType = typeConverter.convertType(operandType);
89  if (!llvmOperandType)
90  return failure();
91 
92  auto loc = op.getLoc();
93  auto resultType = op.getResult().getType();
94  auto llvmResultType = typeConverter.convertType(resultType);
95  if (!llvmResultType)
96  return failure();
97 
98  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
99  rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
100  adaptor.getOperand(), false);
101  return success();
102  }
103 
104  if (!isa<VectorType>(llvmResultType))
105  return failure();
106 
108  op.getOperation(), adaptor.getOperands(), typeConverter,
109  [&](Type llvm1DVectorTy, ValueRange operands) {
110  return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
111  false);
112  },
113  rewriter);
114  }
115 };
116 
117 using CountLeadingZerosOpLowering =
118  IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119 using CountTrailingZerosOpLowering =
120  IntOpWithFlagLowering<math::CountTrailingZerosOp,
121  LLVM::CountTrailingZerosOp>;
122 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
123 
124 // A `expm1` is converted into `exp - 1`.
125 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
127 
128  LogicalResult
129  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override {
131  const auto &typeConverter = *this->getTypeConverter();
132  auto operandType = adaptor.getOperand().getType();
133  auto llvmOperandType = typeConverter.convertType(operandType);
134  if (!llvmOperandType)
135  return failure();
136 
137  auto loc = op.getLoc();
138  auto resultType = op.getResult().getType();
139  auto floatType = cast<FloatType>(
140  typeConverter.convertType(getElementTypeOrSelf(resultType)));
141  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
142  ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
143  ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
144 
145  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
146  LLVM::ConstantOp one;
147  if (LLVM::isCompatibleVectorType(llvmOperandType)) {
148  one = LLVM::ConstantOp::create(
149  rewriter, loc, llvmOperandType,
150  SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
151  floatOne));
152  } else {
153  one =
154  LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
155  }
156  auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
157  expAttrs.getAttrs());
158  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
159  op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
160  return success();
161  }
162 
163  if (!isa<VectorType>(resultType))
164  return rewriter.notifyMatchFailure(op, "expected vector result type");
165 
167  op.getOperation(), adaptor.getOperands(), typeConverter,
168  [&](Type llvm1DVectorTy, ValueRange operands) {
169  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
170  auto splatAttr = SplatElementsAttr::get(
171  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
172  {numElements.isScalable()}),
173  floatOne);
174  auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
175  splatAttr);
176  auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
177  operands[0], expAttrs.getAttrs());
178  return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
179  ValueRange{exp, one},
180  subAttrs.getAttrs());
181  },
182  rewriter);
183  }
184 };
185 
186 // A `log1p` is converted into `log(1 + ...)`.
187 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
189 
190  LogicalResult
191  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
192  ConversionPatternRewriter &rewriter) const override {
193  const auto &typeConverter = *this->getTypeConverter();
194  auto operandType = adaptor.getOperand().getType();
195  auto llvmOperandType = typeConverter.convertType(operandType);
196  if (!llvmOperandType)
197  return rewriter.notifyMatchFailure(op, "unsupported operand type");
198 
199  auto loc = op.getLoc();
200  auto resultType = op.getResult().getType();
201  auto floatType = cast<FloatType>(
202  typeConverter.convertType(getElementTypeOrSelf(resultType)));
203  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
204  ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
205  ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
206 
207  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
208  LLVM::ConstantOp one =
209  isa<VectorType>(llvmOperandType)
210  ? LLVM::ConstantOp::create(
211  rewriter, loc, llvmOperandType,
212  SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
213  floatOne))
214  : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
215  floatOne);
216 
217  auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
218  ValueRange{one, adaptor.getOperand()},
219  addAttrs.getAttrs());
220  rewriter.replaceOpWithNewOp<LLVM::LogOp>(
221  op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
222  return success();
223  }
224 
225  if (!isa<VectorType>(resultType))
226  return rewriter.notifyMatchFailure(op, "expected vector result type");
227 
229  op.getOperation(), adaptor.getOperands(), typeConverter,
230  [&](Type llvm1DVectorTy, ValueRange operands) {
231  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
232  auto splatAttr = SplatElementsAttr::get(
233  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
234  {numElements.isScalable()}),
235  floatOne);
236  auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
237  splatAttr);
238  auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
239  ValueRange{one, operands[0]},
240  addAttrs.getAttrs());
241  return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
242  ValueRange{add}, logAttrs.getAttrs());
243  },
244  rewriter);
245  }
246 };
247 
248 // A `rsqrt` is converted into `1 / sqrt`.
249 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
251 
252  LogicalResult
253  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
254  ConversionPatternRewriter &rewriter) const override {
255  const auto &typeConverter = *this->getTypeConverter();
256  auto operandType = adaptor.getOperand().getType();
257  auto llvmOperandType = typeConverter.convertType(operandType);
258  if (!llvmOperandType)
259  return failure();
260 
261  auto loc = op.getLoc();
262  auto resultType = op.getResult().getType();
263  auto floatType = cast<FloatType>(
264  typeConverter.convertType(getElementTypeOrSelf(resultType)));
265  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
266  ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
267  ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
268 
269  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
270  LLVM::ConstantOp one;
271  if (isa<VectorType>(llvmOperandType)) {
272  one = LLVM::ConstantOp::create(
273  rewriter, loc, llvmOperandType,
274  SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
275  floatOne));
276  } else {
277  one =
278  LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
279  }
280  auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
281  sqrtAttrs.getAttrs());
282  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
283  op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
284  return success();
285  }
286 
287  if (!isa<VectorType>(resultType))
288  return failure();
289 
291  op.getOperation(), adaptor.getOperands(), typeConverter,
292  [&](Type llvm1DVectorTy, ValueRange operands) {
293  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
294  auto splatAttr = SplatElementsAttr::get(
295  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
296  {numElements.isScalable()}),
297  floatOne);
298  auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
299  splatAttr);
300  auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
301  operands[0], sqrtAttrs.getAttrs());
302  return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
303  ValueRange{one, sqrt},
304  divAttrs.getAttrs());
305  },
306  rewriter);
307  }
308 };
309 
310 struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
312 
313  LogicalResult
314  matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
315  ConversionPatternRewriter &rewriter) const override {
316  const auto &typeConverter = *this->getTypeConverter();
317  auto operandType =
318  typeConverter.convertType(adaptor.getOperand().getType());
319  auto resultType = typeConverter.convertType(op.getResult().getType());
320  if (!operandType || !resultType)
321  return failure();
322 
323  rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
324  op, resultType, adaptor.getOperand(), llvm::fcNan);
325  return success();
326  }
327 };
328 
329 struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
331 
332  LogicalResult
333  matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
334  ConversionPatternRewriter &rewriter) const override {
335  const auto &typeConverter = *this->getTypeConverter();
336  auto operandType =
337  typeConverter.convertType(adaptor.getOperand().getType());
338  auto resultType = typeConverter.convertType(op.getResult().getType());
339  if (!operandType || !resultType)
340  return failure();
341 
342  rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
343  op, resultType, adaptor.getOperand(), llvm::fcFinite);
344  return success();
345  }
346 };
347 
348 struct ConvertMathToLLVMPass
349  : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
350  using Base::Base;
351 
352  void runOnOperation() override {
354  LLVMTypeConverter converter(&getContext());
355  populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
357  if (failed(applyPartialConversion(getOperation(), target,
358  std::move(patterns))))
359  signalPassFailure();
360  }
361 };
362 } // namespace
363 
365  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
366  bool approximateLog1p, PatternBenefit benefit) {
367  if (approximateLog1p)
368  patterns.add<Log1pOpLowering>(converter, benefit);
369  // clang-format off
370  patterns.add<
371  IsNaNOpLowering,
372  IsFiniteOpLowering,
373  AbsFOpLowering,
374  AbsIOpLowering,
375  CeilOpLowering,
376  CopySignOpLowering,
377  CosOpLowering,
378  CoshOpLowering,
379  AcosOpLowering,
380  CountLeadingZerosOpLowering,
381  CountTrailingZerosOpLowering,
382  CtPopFOpLowering,
383  Exp2OpLowering,
384  ExpM1OpLowering,
385  ExpOpLowering,
386  FPowIOpLowering,
387  FloorOpLowering,
388  FmaOpLowering,
389  Log10OpLowering,
390  Log2OpLowering,
391  LogOpLowering,
392  PowFOpLowering,
393  RoundEvenOpLowering,
394  RoundOpLowering,
395  RsqrtOpLowering,
396  SinOpLowering,
397  SinhOpLowering,
398  ASinOpLowering,
399  SqrtOpLowering,
400  FTruncOpLowering,
401  TanOpLowering,
402  TanhOpLowering,
403  ATanOpLowering,
404  ATan2OpLowering
405  >(converter, benefit);
406  // clang-format on
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // ConvertToLLVMPatternInterface implementation
411 //===----------------------------------------------------------------------===//
412 
413 namespace {
414 /// Implement the interface to convert Math to LLVM.
415 struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
417  void loadDependentDialects(MLIRContext *context) const final {
418  context->loadDialect<LLVM::LLVMDialect>();
419  }
420 
421  /// Hook for derived dialect interface to provide conversion patterns
422  /// and mark dialect legal for the conversion target.
424  ConversionTarget &target, LLVMTypeConverter &typeConverter,
425  RewritePatternSet &patterns) const final {
427  }
428 };
429 } // namespace
430 
432  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
433  dialect->addInterfaces<MathToLLVMDialectInterface>();
434  });
435 }
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:223
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:813
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Definition: MathToLLVM.cpp:364
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
Definition: MathToLLVM.cpp:431