MLIR  21.0.0git
MathToLibm.cpp
Go to the documentation of this file.
1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/BuiltinDialect.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTMATHTOLIBMPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 // Pattern to convert vector operations to scalar operations. This is needed as
31 // libm calls require scalars.
32 template <typename Op>
33 struct VecOpToScalarOp : public OpRewritePattern<Op> {
34 public:
36 
37  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
38 };
39 // Pattern to promote an op of a smaller floating point type to F32.
40 template <typename Op>
41 struct PromoteOpToF32 : public OpRewritePattern<Op> {
42 public:
44 
45  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46 };
47 // Pattern to convert scalar math operations to calls to libm functions.
48 // Additionally the libm function signatures are declared.
49 template <typename Op>
50 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
51 public:
53  ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
54  StringRef floatFunc, StringRef doubleFunc)
55  : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
56  doubleFunc(doubleFunc) {};
57 
58  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
59 
60 private:
61  std::string floatFunc, doubleFunc;
62 };
63 
64 template <typename OpTy>
65 void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
66  MLIRContext *ctx, StringRef floatFunc,
67  StringRef doubleFunc) {
68  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
70 }
71 
72 } // namespace
73 
74 template <typename Op>
75 LogicalResult
76 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
77  auto opType = op.getType();
78  auto loc = op.getLoc();
79  auto vecType = dyn_cast<VectorType>(opType);
80 
81  if (!vecType)
82  return failure();
83  if (!vecType.hasRank())
84  return failure();
85  auto shape = vecType.getShape();
86  int64_t numElements = vecType.getNumElements();
87 
88  Value result = rewriter.create<arith::ConstantOp>(
90  vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
91  SmallVector<int64_t> strides = computeStrides(shape);
92  for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
93  SmallVector<int64_t> positions = delinearize(linearIndex, strides);
94  SmallVector<Value> operands;
95  for (auto input : op->getOperands())
96  operands.push_back(
97  rewriter.create<vector::ExtractOp>(loc, input, positions));
98  Value scalarOp =
99  rewriter.create<Op>(loc, vecType.getElementType(), operands);
100  result =
101  rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
102  }
103  rewriter.replaceOp(op, {result});
104  return success();
105 }
106 
107 template <typename Op>
108 LogicalResult
109 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
110  auto opType = op.getType();
111  if (!isa<Float16Type, BFloat16Type>(opType))
112  return failure();
113 
114  auto loc = op.getLoc();
115  auto f32 = rewriter.getF32Type();
116  auto extendedOperands = llvm::to_vector(
117  llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
118  return rewriter.create<arith::ExtFOp>(loc, f32, operand);
119  }));
120  auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
121  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
122  return success();
123 }
124 
125 template <typename Op>
126 LogicalResult
127 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
128  PatternRewriter &rewriter) const {
129  auto module = SymbolTable::getNearestSymbolTable(op);
130  auto type = op.getType();
131  if (!isa<Float32Type, Float64Type>(type))
132  return failure();
133 
134  auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
135  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
136  SymbolTable::lookupSymbolIn(module, name));
137  // Forward declare function if it hasn't already been
138  if (!opFunc) {
139  OpBuilder::InsertionGuard guard(rewriter);
140  rewriter.setInsertionPointToStart(&module->getRegion(0).front());
141  auto opFunctionTy = FunctionType::get(
142  rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
143  opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
144  opFunctionTy);
145  opFunc.setPrivate();
146 
147  // By definition Math dialect operations imply LLVM's "readnone"
148  // function attribute, so we can set it here to provide more
149  // optimization opportunities (e.g. LICM) for backends targeting LLVM IR.
150  // This will have to be changed, when strict FP behavior is supported
151  // by Math dialect.
152  opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
153  UnitAttr::get(rewriter.getContext()));
154  }
155  assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
156 
157  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
158  op->getOperands());
159 
160  return success();
161 }
162 
164  PatternBenefit benefit) {
165  MLIRContext *ctx = patterns.getContext();
166 
167  populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
168  populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
169  populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf",
170  "acosh");
171  populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
172  populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf",
173  "asinh");
174  populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f",
175  "atan2");
176  populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
177  populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf",
178  "atanh");
179  populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
180  populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
181  populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
182  populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
183  populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
184  populatePatternsForOp<math::ErfcOp>(patterns, benefit, ctx, "erfcf", "erfc");
185  populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
186  populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
187  populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",
188  "expm1");
189  populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf",
190  "floor");
191  populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
192  populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
193  populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
194  populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f",
195  "log10");
196  populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf",
197  "log1p");
198  populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
199  populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
200  "roundeven");
201  populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf",
202  "round");
203  populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
204  populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
205  populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
206  populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf",
207  "rsqrt");
208  populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
209  populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
210  populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf",
211  "trunc");
212 }
213 
214 namespace {
215 struct ConvertMathToLibmPass
216  : public impl::ConvertMathToLibmPassBase<ConvertMathToLibmPass> {
217  void runOnOperation() override;
218 };
219 } // namespace
220 
221 void ConvertMathToLibmPass::runOnOperation() {
222  auto module = getOperation();
223 
226 
227  ConversionTarget target(getContext());
228  target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
229  vector::VectorDialect>();
230  target.addIllegalDialect<math::MathDialect>();
231  if (failed(applyPartialConversion(module, target, std::move(patterns))))
232  signalPassFailure();
233 }
static MLIRContext * getContext(OpFoldResult val)
FloatType getF32Type()
Definition: Builders.cpp:43
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of the symbol-table operation 'op'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table, searching outward starting from the given operation 'from'.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the given list with patterns that convert from Math to Libm calls.
Definition: MathToLibm.cpp:163
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358