MLIR  16.0.0git
MathToLibm.cpp
Go to the documentation of this file.
1 //===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTMATHTOLIBM
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 // Pattern to convert vector operations to scalar operations. This is needed as
31 // libm calls require scalars.
32 template <typename Op>
33 struct VecOpToScalarOp : public OpRewritePattern<Op> {
34 public:
36 
37  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
38 };
39 // Pattern to promote an op of a smaller floating point type to F32.
40 template <typename Op>
41 struct PromoteOpToF32 : public OpRewritePattern<Op> {
42 public:
44 
45  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46 };
47 // Pattern to convert scalar math operations to calls to libm functions.
48 // Additionally the libm function signatures are declared.
49 template <typename Op>
50 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
51 public:
53  ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
54  StringRef doubleFunc, PatternBenefit benefit)
55  : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
56  doubleFunc(doubleFunc){};
57 
58  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
59 
60 private:
61  std::string floatFunc, doubleFunc;
62 };
63 } // namespace
64 
65 template <typename Op>
67 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
68  auto opType = op.getType();
69  auto loc = op.getLoc();
70  auto vecType = opType.template dyn_cast<VectorType>();
71 
72  if (!vecType)
73  return failure();
74  if (!vecType.hasRank())
75  return failure();
76  auto shape = vecType.getShape();
77  int64_t numElements = vecType.getNumElements();
78 
79  Value result = rewriter.create<arith::ConstantOp>(
81  vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
82  SmallVector<int64_t> ones(shape.size(), 1);
83  SmallVector<int64_t> strides = computeStrides(shape, ones);
84  for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
85  SmallVector<int64_t> positions = delinearize(strides, linearIndex);
86  SmallVector<Value> operands;
87  for (auto input : op->getOperands())
88  operands.push_back(
89  rewriter.create<vector::ExtractOp>(loc, input, positions));
90  Value scalarOp =
91  rewriter.create<Op>(loc, vecType.getElementType(), operands);
92  result =
93  rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
94  }
95  rewriter.replaceOp(op, {result});
96  return success();
97 }
98 
99 template <typename Op>
101 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
102  auto opType = op.getType();
103  if (!opType.template isa<Float16Type, BFloat16Type>())
104  return failure();
105 
106  auto loc = op.getLoc();
107  auto f32 = rewriter.getF32Type();
108  auto extendedOperands = llvm::to_vector(
109  llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
110  return rewriter.create<arith::ExtFOp>(loc, f32, operand);
111  }));
112  auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
113  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
114  return success();
115 }
116 
117 template <typename Op>
119 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
120  PatternRewriter &rewriter) const {
121  auto module = SymbolTable::getNearestSymbolTable(op);
122  auto type = op.getType();
123  if (!type.template isa<Float32Type, Float64Type>())
124  return failure();
125 
126  auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
127  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
128  SymbolTable::lookupSymbolIn(module, name));
129  // Forward declare function if it hasn't already been
130  if (!opFunc) {
131  OpBuilder::InsertionGuard guard(rewriter);
132  rewriter.setInsertionPointToStart(&module->getRegion(0).front());
133  auto opFunctionTy = FunctionType::get(
134  rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
135  opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
136  opFunctionTy);
137  opFunc.setPrivate();
138 
139  // By definition Math dialect operations imply LLVM's "readnone"
140  // function attribute, so we can set it here to provide more
141  // optimization opportunities (e.g. LICM) for backends targeting LLVM IR.
142  // This will have to be changed, when strict FP behavior is supported
143  // by Math dialect.
144  opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
145  UnitAttr::get(rewriter.getContext()));
146  }
147  assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
148 
149  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
150  op->getOperands());
151 
152  return success();
153 }
154 
156  RewritePatternSet &patterns, PatternBenefit benefit,
157  llvm::Optional<PatternBenefit> log1pBenefit) {
158  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
159  VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
160  VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
161  VecOpToScalarOp<math::RoundEvenOp>,
162  VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
163  VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
164  patterns.getContext(), benefit);
165  patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
166  PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
167  PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
168  PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
169  PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>,
170  PromoteOpToF32<math::TruncOp>>(patterns.getContext(), benefit);
171  patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
172  "atan", benefit);
173  patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
174  "atan2f", "atan2", benefit);
175  patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
176  "erf", benefit);
177  patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
178  "expm1f", "expm1", benefit);
179  patterns.add<ScalarOpToLibmCall<math::TanOp>>(patterns.getContext(), "tanf",
180  "tan", benefit);
181  patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
182  "tanh", benefit);
183  patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(
184  patterns.getContext(), "roundevenf", "roundeven", benefit);
185  patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
186  "roundf", "round", benefit);
187  patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
188  "cos", benefit);
189  patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
190  "sin", benefit);
191  patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(
192  patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit));
193  patterns.add<ScalarOpToLibmCall<math::FloorOp>>(patterns.getContext(),
194  "floorf", "floor", benefit);
195  patterns.add<ScalarOpToLibmCall<math::CeilOp>>(patterns.getContext(), "ceilf",
196  "ceil", benefit);
197  patterns.add<ScalarOpToLibmCall<math::TruncOp>>(patterns.getContext(),
198  "truncf", "trunc", benefit);
199 }
200 
201 namespace {
202 struct ConvertMathToLibmPass
203  : public impl::ConvertMathToLibmBase<ConvertMathToLibmPass> {
204  void runOnOperation() override;
205 };
206 } // namespace
207 
208 void ConvertMathToLibmPass::runOnOperation() {
209  auto module = getOperation();
210 
211  RewritePatternSet patterns(&getContext());
212  populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
213 
214  ConversionTarget target(getContext());
215  target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
216  vector::VectorDialect>();
217  target.addIllegalDialect<math::MathDialect>();
218  if (failed(applyPartialConversion(module, target, std::move(patterns))))
219  signalPassFailure();
220 }
221 
222 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
223  return std::make_unique<ConvertMathToLibmPass>();
224 }
Location getUnknownLoc()
Definition: Builders.cpp:26
Include the generated interface declarations.
SmallVector< int64_t, 4 > computeStrides(ArrayRef< int64_t > shape, ArrayRef< int64_t > sizes)
Given the shape and sizes of a vector, returns the corresponding strides for each dimension...
Definition: VectorUtils.cpp:54
MLIRContext * getContext() const
Definition: Builders.h:54
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:600
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
FloatType getF32Type()
Definition: Builders.cpp:44
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit, llvm::Optional< PatternBenefit > log1pBenefit=llvm::None)
Populate the given list with patterns that convert from Math to Libm calls.
Definition: MathToLibm.cpp:155
std::unique_ptr< OperationPass< ModuleOp > > createConvertMathToLibmPass()
Create a pass to convert Math operations to libm calls.
Definition: MathToLibm.cpp:222
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:418
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:32
SmallVector< int64_t, 4 > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offsets in each dimension for a de-linearized index.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:382
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:299
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value; otherwise, add a new attribute with the specified name and value.
Definition: Operation.h:395
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations of these dialects are not supported by the target.
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:114
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This provides public APIs that all operations should have.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class describes a specific conversion target.
MLIRContext * getContext() const