MLIR  20.0.0git
MathToLibm.cpp
Go to the documentation of this file.
1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/BuiltinDialect.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTMATHTOLIBM
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 // Pattern to convert vector operations to scalar operations. This is needed as
31 // libm calls require scalars.
32 template <typename Op>
33 struct VecOpToScalarOp : public OpRewritePattern<Op> {
34 public:
36 
37  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
38 };
39 // Pattern to promote an op of a smaller floating point type to F32.
40 template <typename Op>
41 struct PromoteOpToF32 : public OpRewritePattern<Op> {
42 public:
44 
45  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46 };
47 // Pattern to convert scalar math operations to calls to libm functions.
48 // Additionally the libm function signatures are declared.
49 template <typename Op>
50 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
51 public:
53  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
54  StringRef doubleFunc)
55  : OpRewritePattern<Op>(context), floatFunc(floatFunc),
56  doubleFunc(doubleFunc){};
57 
58  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
59 
60 private:
61  std::string floatFunc, doubleFunc;
62 };
63 
64 template <typename OpTy>
65 void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
66  StringRef floatFunc, StringRef doubleFunc) {
67  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
69 }
70 
71 } // namespace
72 
73 template <typename Op>
74 LogicalResult
75 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
76  auto opType = op.getType();
77  auto loc = op.getLoc();
78  auto vecType = dyn_cast<VectorType>(opType);
79 
80  if (!vecType)
81  return failure();
82  if (!vecType.hasRank())
83  return failure();
84  auto shape = vecType.getShape();
85  int64_t numElements = vecType.getNumElements();
86 
87  Value result = rewriter.create<arith::ConstantOp>(
89  vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
90  SmallVector<int64_t> strides = computeStrides(shape);
91  for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
92  SmallVector<int64_t> positions = delinearize(linearIndex, strides);
93  SmallVector<Value> operands;
94  for (auto input : op->getOperands())
95  operands.push_back(
96  rewriter.create<vector::ExtractOp>(loc, input, positions));
97  Value scalarOp =
98  rewriter.create<Op>(loc, vecType.getElementType(), operands);
99  result =
100  rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
101  }
102  rewriter.replaceOp(op, {result});
103  return success();
104 }
105 
106 template <typename Op>
107 LogicalResult
108 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
109  auto opType = op.getType();
110  if (!isa<Float16Type, BFloat16Type>(opType))
111  return failure();
112 
113  auto loc = op.getLoc();
114  auto f32 = rewriter.getF32Type();
115  auto extendedOperands = llvm::to_vector(
116  llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
117  return rewriter.create<arith::ExtFOp>(loc, f32, operand);
118  }));
119  auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
120  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
121  return success();
122 }
123 
124 template <typename Op>
125 LogicalResult
126 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
127  PatternRewriter &rewriter) const {
128  auto module = SymbolTable::getNearestSymbolTable(op);
129  auto type = op.getType();
130  if (!isa<Float32Type, Float64Type>(type))
131  return failure();
132 
133  auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
134  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
135  SymbolTable::lookupSymbolIn(module, name));
136  // Forward declare function if it hasn't already been
137  if (!opFunc) {
138  OpBuilder::InsertionGuard guard(rewriter);
139  rewriter.setInsertionPointToStart(&module->getRegion(0).front());
140  auto opFunctionTy = FunctionType::get(
141  rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
142  opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
143  opFunctionTy);
144  opFunc.setPrivate();
145 
146  // By definition Math dialect operations imply LLVM's "readnone"
147  // function attribute, so we can set it here to provide more
148  // optimization opportunities (e.g. LICM) for backends targeting LLVM IR.
149  // This will have to be changed, when strict FP behavior is supported
150  // by Math dialect.
151  opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
152  UnitAttr::get(rewriter.getContext()));
153  }
154  assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
155 
156  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
157  op->getOperands());
158 
159  return success();
160 }
161 
163  MLIRContext *ctx = patterns.getContext();
164 
165  populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
166  populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
167  populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
168  populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
169  populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
170  populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
171  populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
172  populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
173  populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
174  populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
175  populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
176  populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
177  populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
178  populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
179  populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
180  populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
181  populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
182  populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
183  populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
184  populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
185  populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
186  populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
187  populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
188  populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
189  "roundeven");
190  populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
191  populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
192  populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
193  populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
194  populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
195  populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
196  populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
197  populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
198 }
199 
200 namespace {
201 struct ConvertMathToLibmPass
202  : public impl::ConvertMathToLibmBase<ConvertMathToLibmPass> {
203  void runOnOperation() override;
204 };
205 } // namespace
206 
207 void ConvertMathToLibmPass::runOnOperation() {
208  auto module = getOperation();
209 
212 
213  ConversionTarget target(getContext());
214  target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
215  vector::VectorDialect>();
216  target.addIllegalDialect<math::MathDialect>();
217  if (failed(applyPartialConversion(module, target, std::move(patterns))))
218  signalPassFailure();
219 }
220 
221 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
222  return std::make_unique<ConvertMathToLibmPass>();
223 }
static MLIRContext * getContext(OpFoldResult val)
FloatType getF32Type()
Definition: Builders.cpp:87
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
Definition: Operation.h:582
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table enclosing the given operation 'from'.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Math to Libm calls.
Definition: MathToLibm.cpp:162
const FrozenRewritePatternSet & patterns
std::unique_ptr< OperationPass< ModuleOp > > createConvertMathToLibmPass()
Create a pass to convert Math operations to libm calls.
Definition: MathToLibm.cpp:221
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358