MLIR  19.0.0git
MathToLibm.cpp
Go to the documentation of this file.
1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/BuiltinDialect.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTMATHTOLIBM
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 // Pattern to convert vector operations to scalar operations. This is needed as
31 // libm calls require scalars.
32 template <typename Op>
33 struct VecOpToScalarOp : public OpRewritePattern<Op> {
34 public:
36 
37  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
38 };
39 // Pattern to promote an op of a smaller floating point type to F32.
40 template <typename Op>
41 struct PromoteOpToF32 : public OpRewritePattern<Op> {
42 public:
44 
45  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46 };
47 // Pattern to convert scalar math operations to calls to libm functions.
48 // Additionally the libm function signatures are declared.
49 template <typename Op>
50 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
51 public:
53  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
54  StringRef doubleFunc)
55  : OpRewritePattern<Op>(context), floatFunc(floatFunc),
56  doubleFunc(doubleFunc){};
57 
58  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
59 
60 private:
61  std::string floatFunc, doubleFunc;
62 };
63 
64 template <typename OpTy>
65 void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
66  StringRef floatFunc, StringRef doubleFunc) {
67  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
69 }
70 
71 } // namespace
72 
73 template <typename Op>
75 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
76  auto opType = op.getType();
77  auto loc = op.getLoc();
78  auto vecType = dyn_cast<VectorType>(opType);
79 
80  if (!vecType)
81  return failure();
82  if (!vecType.hasRank())
83  return failure();
84  auto shape = vecType.getShape();
85  int64_t numElements = vecType.getNumElements();
86 
87  Value result = rewriter.create<arith::ConstantOp>(
89  vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
90  SmallVector<int64_t> strides = computeStrides(shape);
91  for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
92  SmallVector<int64_t> positions = delinearize(linearIndex, strides);
93  SmallVector<Value> operands;
94  for (auto input : op->getOperands())
95  operands.push_back(
96  rewriter.create<vector::ExtractOp>(loc, input, positions));
97  Value scalarOp =
98  rewriter.create<Op>(loc, vecType.getElementType(), operands);
99  result =
100  rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
101  }
102  rewriter.replaceOp(op, {result});
103  return success();
104 }
105 
106 template <typename Op>
108 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
109  auto opType = op.getType();
110  if (!isa<Float16Type, BFloat16Type>(opType))
111  return failure();
112 
113  auto loc = op.getLoc();
114  auto f32 = rewriter.getF32Type();
115  auto extendedOperands = llvm::to_vector(
116  llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
117  return rewriter.create<arith::ExtFOp>(loc, f32, operand);
118  }));
119  auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
120  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
121  return success();
122 }
123 
124 template <typename Op>
126 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
127  PatternRewriter &rewriter) const {
128  auto module = SymbolTable::getNearestSymbolTable(op);
129  auto type = op.getType();
130  if (!isa<Float32Type, Float64Type>(type))
131  return failure();
132 
133  auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
134  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
135  SymbolTable::lookupSymbolIn(module, name));
136  // Forward declare function if it hasn't already been
137  if (!opFunc) {
138  OpBuilder::InsertionGuard guard(rewriter);
139  rewriter.setInsertionPointToStart(&module->getRegion(0).front());
140  auto opFunctionTy = FunctionType::get(
141  rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
142  opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
143  opFunctionTy);
144  opFunc.setPrivate();
145 
146  // By definition Math dialect operations imply LLVM's "readnone"
147  // function attribute, so we can set it here to provide more
148  // optimization opportunities (e.g. LICM) for backends targeting LLVM IR.
149  // This will have to be changed, when strict FP behavior is supported
150  // by Math dialect.
151  opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
152  UnitAttr::get(rewriter.getContext()));
153  }
154  assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
155 
156  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
157  op->getOperands());
158 
159  return success();
160 }
161 
163  MLIRContext *ctx = patterns.getContext();
164 
165  populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
166  populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
167  populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
168  populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
169  populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
170  populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
171  populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
172  populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
173  populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
174  populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
175  populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
176  populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
177  populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
178  populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
179  populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
180  populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
181  populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
182  populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
183  populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
184  populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
185  populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
186  populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
187  populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
188  populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
189  "roundeven");
190  populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
191  populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
192  populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
193  populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
194  populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
195  populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
196  populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
197 }
198 
199 namespace {
200 struct ConvertMathToLibmPass
201  : public impl::ConvertMathToLibmBase<ConvertMathToLibmPass> {
202  void runOnOperation() override;
203 };
204 } // namespace
205 
206 void ConvertMathToLibmPass::runOnOperation() {
207  auto module = getOperation();
208 
209  RewritePatternSet patterns(&getContext());
211 
212  ConversionTarget target(getContext());
213  target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
214  vector::VectorDialect>();
215  target.addIllegalDialect<math::MathDialect>();
216  if (failed(applyPartialConversion(module, target, std::move(patterns))))
217  signalPassFailure();
218 }
219 
220 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
221  return std::make_unique<ConvertMathToLibmPass>();
222 }
static MLIRContext * getContext(OpFoldResult val)
FloatType getF32Type()
Definition: Builders.cpp:63
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This provides public APIs that all operations should have.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table to the given operation 'from'.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Math to Libm calls.
Definition: MathToLibm.cpp:162
std::unique_ptr< OperationPass< ModuleOp > > createConvertMathToLibmPass()
Create a pass to convert Math operations to libm calls.
Definition: MathToLibm.cpp:220
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358