MLIR  21.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1 //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 
21 #include "llvm/ADT/FloatingPointMode.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 template <typename SourceOp, typename TargetOp>
33 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
34 
35 template <typename SourceOp, typename TargetOp>
36 using ConvertFMFMathToLLVMPattern =
38 
39 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41 using CopySignOpLowering =
42  ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43 using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45 using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
46 using CtPopFOpLowering =
48 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50 using FloorOpLowering =
51  ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52 using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53 using Log10OpLowering =
54  ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56 using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
57 using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58 using FPowIOpLowering =
59  ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60 using RoundEvenOpLowering =
61  ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62 using RoundOpLowering =
63  ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64 using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65 using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66 using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67 using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
68 using FTruncOpLowering =
69  ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70 using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71 using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72 using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73 using ATan2OpLowering =
74  ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
75 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
76 // TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
77 // may be better to separate the patterns.
78 template <typename MathOp, typename LLVMOp>
79 struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
81  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
82 
83  LogicalResult
84  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
85  ConversionPatternRewriter &rewriter) const override {
86  const auto &typeConverter = *this->getTypeConverter();
87  auto operandType = adaptor.getOperand().getType();
88  auto llvmOperandType = typeConverter.convertType(operandType);
89  if (!llvmOperandType)
90  return failure();
91 
92  auto loc = op.getLoc();
93  auto resultType = op.getResult().getType();
94  auto llvmResultType = typeConverter.convertType(resultType);
95  if (!llvmResultType)
96  return failure();
97 
98  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
99  rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
100  adaptor.getOperand(), false);
101  return success();
102  }
103 
104  if (!isa<VectorType>(llvmResultType))
105  return failure();
106 
108  op.getOperation(), adaptor.getOperands(), typeConverter,
109  [&](Type llvm1DVectorTy, ValueRange operands) {
110  return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
111  false);
112  },
113  rewriter);
114  }
115 };
116 
117 using CountLeadingZerosOpLowering =
118  IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119 using CountTrailingZerosOpLowering =
120  IntOpWithFlagLowering<math::CountTrailingZerosOp,
121  LLVM::CountTrailingZerosOp>;
122 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
123 
124 // A `expm1` is converted into `exp - 1`.
125 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
127 
128  LogicalResult
129  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override {
131  const auto &typeConverter = *this->getTypeConverter();
132  auto operandType = adaptor.getOperand().getType();
133  auto llvmOperandType = typeConverter.convertType(operandType);
134  if (!llvmOperandType)
135  return failure();
136 
137  auto loc = op.getLoc();
138  auto resultType = op.getResult().getType();
139  auto floatType = cast<FloatType>(
140  typeConverter.convertType(getElementTypeOrSelf(resultType)));
141  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
142  ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
143  ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
144 
145  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
146  LLVM::ConstantOp one;
147  if (LLVM::isCompatibleVectorType(llvmOperandType)) {
148  one = rewriter.create<LLVM::ConstantOp>(
149  loc, llvmOperandType,
150  SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
151  floatOne));
152  } else {
153  one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
154  }
155  auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
156  expAttrs.getAttrs());
157  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
158  op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
159  return success();
160  }
161 
162  if (!isa<VectorType>(resultType))
163  return rewriter.notifyMatchFailure(op, "expected vector result type");
164 
166  op.getOperation(), adaptor.getOperands(), typeConverter,
167  [&](Type llvm1DVectorTy, ValueRange operands) {
168  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
169  auto splatAttr = SplatElementsAttr::get(
170  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
171  {numElements.isScalable()}),
172  floatOne);
173  auto one =
174  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
175  auto exp = rewriter.create<LLVM::ExpOp>(
176  loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
177  return rewriter.create<LLVM::FSubOp>(
178  loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
179  },
180  rewriter);
181  }
182 };
183 
184 // A `log1p` is converted into `log(1 + ...)`.
185 struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
187 
188  LogicalResult
189  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
190  ConversionPatternRewriter &rewriter) const override {
191  const auto &typeConverter = *this->getTypeConverter();
192  auto operandType = adaptor.getOperand().getType();
193  auto llvmOperandType = typeConverter.convertType(operandType);
194  if (!llvmOperandType)
195  return rewriter.notifyMatchFailure(op, "unsupported operand type");
196 
197  auto loc = op.getLoc();
198  auto resultType = op.getResult().getType();
199  auto floatType = cast<FloatType>(
200  typeConverter.convertType(getElementTypeOrSelf(resultType)));
201  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
202  ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
203  ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
204 
205  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
206  LLVM::ConstantOp one =
207  isa<VectorType>(llvmOperandType)
208  ? rewriter.create<LLVM::ConstantOp>(
209  loc, llvmOperandType,
210  SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
211  floatOne))
212  : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
213  floatOne);
214 
215  auto add = rewriter.create<LLVM::FAddOp>(
216  loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
217  addAttrs.getAttrs());
218  rewriter.replaceOpWithNewOp<LLVM::LogOp>(
219  op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
220  return success();
221  }
222 
223  if (!isa<VectorType>(resultType))
224  return rewriter.notifyMatchFailure(op, "expected vector result type");
225 
227  op.getOperation(), adaptor.getOperands(), typeConverter,
228  [&](Type llvm1DVectorTy, ValueRange operands) {
229  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
230  auto splatAttr = SplatElementsAttr::get(
231  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
232  {numElements.isScalable()}),
233  floatOne);
234  auto one =
235  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
236  auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
237  ValueRange{one, operands[0]},
238  addAttrs.getAttrs());
239  return rewriter.create<LLVM::LogOp>(
240  loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
241  },
242  rewriter);
243  }
244 };
245 
246 // A `rsqrt` is converted into `1 / sqrt`.
247 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
249 
250  LogicalResult
251  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
252  ConversionPatternRewriter &rewriter) const override {
253  const auto &typeConverter = *this->getTypeConverter();
254  auto operandType = adaptor.getOperand().getType();
255  auto llvmOperandType = typeConverter.convertType(operandType);
256  if (!llvmOperandType)
257  return failure();
258 
259  auto loc = op.getLoc();
260  auto resultType = op.getResult().getType();
261  auto floatType = cast<FloatType>(
262  typeConverter.convertType(getElementTypeOrSelf(resultType)));
263  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
264  ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
265  ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
266 
267  if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
268  LLVM::ConstantOp one;
269  if (isa<VectorType>(llvmOperandType)) {
270  one = rewriter.create<LLVM::ConstantOp>(
271  loc, llvmOperandType,
272  SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
273  floatOne));
274  } else {
275  one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
276  }
277  auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
278  sqrtAttrs.getAttrs());
279  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
280  op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
281  return success();
282  }
283 
284  if (!isa<VectorType>(resultType))
285  return failure();
286 
288  op.getOperation(), adaptor.getOperands(), typeConverter,
289  [&](Type llvm1DVectorTy, ValueRange operands) {
290  auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
291  auto splatAttr = SplatElementsAttr::get(
292  mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
293  {numElements.isScalable()}),
294  floatOne);
295  auto one =
296  rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
297  auto sqrt = rewriter.create<LLVM::SqrtOp>(
298  loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
299  return rewriter.create<LLVM::FDivOp>(
300  loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
301  },
302  rewriter);
303  }
304 };
305 
306 struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
308 
309  LogicalResult
310  matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
311  ConversionPatternRewriter &rewriter) const override {
312  const auto &typeConverter = *this->getTypeConverter();
313  auto operandType =
314  typeConverter.convertType(adaptor.getOperand().getType());
315  auto resultType = typeConverter.convertType(op.getResult().getType());
316  if (!operandType || !resultType)
317  return failure();
318 
319  rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
320  op, resultType, adaptor.getOperand(), llvm::fcNan);
321  return success();
322  }
323 };
324 
325 struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
327 
328  LogicalResult
329  matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
330  ConversionPatternRewriter &rewriter) const override {
331  const auto &typeConverter = *this->getTypeConverter();
332  auto operandType =
333  typeConverter.convertType(adaptor.getOperand().getType());
334  auto resultType = typeConverter.convertType(op.getResult().getType());
335  if (!operandType || !resultType)
336  return failure();
337 
338  rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
339  op, resultType, adaptor.getOperand(), llvm::fcFinite);
340  return success();
341  }
342 };
343 
344 struct ConvertMathToLLVMPass
345  : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
346  using Base::Base;
347 
348  void runOnOperation() override {
350  LLVMTypeConverter converter(&getContext());
351  populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
353  if (failed(applyPartialConversion(getOperation(), target,
354  std::move(patterns))))
355  signalPassFailure();
356  }
357 };
358 } // namespace
359 
361  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
362  bool approximateLog1p, PatternBenefit benefit) {
363  if (approximateLog1p)
364  patterns.add<Log1pOpLowering>(converter, benefit);
365  // clang-format off
366  patterns.add<
367  IsNaNOpLowering,
368  IsFiniteOpLowering,
369  AbsFOpLowering,
370  AbsIOpLowering,
371  CeilOpLowering,
372  CopySignOpLowering,
373  CosOpLowering,
374  CoshOpLowering,
375  AcosOpLowering,
376  CountLeadingZerosOpLowering,
377  CountTrailingZerosOpLowering,
378  CtPopFOpLowering,
379  Exp2OpLowering,
380  ExpM1OpLowering,
381  ExpOpLowering,
382  FPowIOpLowering,
383  FloorOpLowering,
384  FmaOpLowering,
385  Log10OpLowering,
386  Log2OpLowering,
387  LogOpLowering,
388  PowFOpLowering,
389  RoundEvenOpLowering,
390  RoundOpLowering,
391  RsqrtOpLowering,
392  SinOpLowering,
393  SinhOpLowering,
394  ASinOpLowering,
395  SqrtOpLowering,
396  FTruncOpLowering,
397  TanOpLowering,
398  TanhOpLowering,
399  ATanOpLowering,
400  ATan2OpLowering
401  >(converter, benefit);
402  // clang-format on
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // ConvertToLLVMPatternInterface implementation
407 //===----------------------------------------------------------------------===//
408 
409 namespace {
410 /// Implement the interface to convert Math to LLVM.
411 struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
413  void loadDependentDialects(MLIRContext *context) const final {
414  context->loadDialect<LLVM::LLVMDialect>();
415  }
416 
417  /// Hook for derived dialect interface to provide conversion patterns
418  /// and mark dialect legal for the conversion target.
420  ConversionTarget &target, LLVMTypeConverter &typeConverter,
421  RewritePatternSet &patterns) const final {
423  }
424 };
425 } // namespace
426 
428  registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
429  dialect->addInterfaces<MathToLLVMDialectInterface>();
430  });
431 }
static MLIRContext * getContext(OpFoldResult val)
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:252
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:195
LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: Pattern.h:209
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
const LLVMTypeConverter * getTypeConverter() const
Definition: Pattern.cpp:27
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:814
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Definition: MathToLLVM.cpp:360
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
Definition: MathToLLVM.cpp:427