MLIR  22.0.0git
AlgebraicSimplification.cpp
Go to the documentation of this file.
1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites based on the basic rules of algebra
10 // (Commutativity, associativity, etc...) and strength reductions for math
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include <climits>
23 
24 using namespace mlir;
25 
26 //----------------------------------------------------------------------------//
27 // PowFOp strength reduction.
28 //----------------------------------------------------------------------------//
29 
30 namespace {
31 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32 public:
34 
35  LogicalResult matchAndRewrite(math::PowFOp op,
36  PatternRewriter &rewriter) const final;
37 };
38 } // namespace
39 
40 LogicalResult
41 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42  PatternRewriter &rewriter) const {
43  Location loc = op.getLoc();
44  Value x = op.getLhs();
45 
46  FloatAttr scalarExponent;
47  DenseFPElementsAttr vectorExponent;
48 
49  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
50  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51 
52  // Returns true if exponent is a constant equal to `value`.
53  auto isExponentValue = [&](double value) -> bool {
54  if (isScalar)
55  return scalarExponent.getValue().isExactlyValue(value);
56 
57  if (isVector && vectorExponent.isSplat())
58  return vectorExponent.getSplatValue<FloatAttr>()
59  .getValue()
60  .isExactlyValue(value);
61 
62  return false;
63  };
64 
65  // Maybe broadcasts scalar value into vector type compatible with `op`.
66  auto bcast = [&](Value value) -> Value {
67  if (auto vec = dyn_cast<VectorType>(op.getType()))
68  return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
69  return value;
70  };
71 
72  // Replace `pow(x, 1.0)` with `x`.
73  if (isExponentValue(1.0)) {
74  rewriter.replaceOp(op, x);
75  return success();
76  }
77 
78  // Replace `pow(x, 2.0)` with `x * x`.
79  if (isExponentValue(2.0)) {
80  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81  return success();
82  }
83 
84  // Replace `pow(x, 3.0)` with `x * x * x`.
85  if (isExponentValue(3.0)) {
86  Value square =
87  arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
88  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89  return success();
90  }
91 
92  // Replace `pow(x, -1.0)` with `1.0 / x`.
93  if (isExponentValue(-1.0)) {
94  Value one = arith::ConstantOp::create(
95  rewriter, loc,
96  rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
97  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
98  return success();
99  }
100 
101  // Replace `pow(x, 0.5)` with `sqrt(x)`.
102  if (isExponentValue(0.5)) {
103  rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
104  return success();
105  }
106 
107  // Replace `pow(x, -0.5)` with `rsqrt(x)`.
108  if (isExponentValue(-0.5)) {
109  rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
110  return success();
111  }
112 
113  // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
114  if (isExponentValue(0.75)) {
115  Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
116  Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
117  rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
118  ValueRange{powHalf, powQuarter});
119  return success();
120  }
121 
122  return failure();
123 }
124 
125 //----------------------------------------------------------------------------//
126 // FPowIOp/IPowIOp strength reduction.
127 //----------------------------------------------------------------------------//
128 
129 namespace {
130 template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
131 struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
132 
133  unsigned exponentThreshold;
134 
135 public:
136  PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
137  PatternBenefit benefit = 1,
138  ArrayRef<StringRef> generatedNames = {})
139  : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
140  exponentThreshold(exponentThreshold) {}
141 
142  LogicalResult matchAndRewrite(PowIOpTy op,
143  PatternRewriter &rewriter) const final;
144 };
145 } // namespace
146 
147 template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
148 LogicalResult
149 PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
150  PowIOpTy op, PatternRewriter &rewriter) const {
151  Location loc = op.getLoc();
152  Value base = op.getLhs();
153 
154  IntegerAttr scalarExponent;
155  DenseIntElementsAttr vectorExponent;
156 
157  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
158  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
159 
160  // Simplify cases with known exponent value.
161  int64_t exponentValue = 0;
162  if (isScalar)
163  exponentValue = scalarExponent.getInt();
164  else if (isVector && vectorExponent.isSplat())
165  exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
166  else
167  return failure();
168 
169  // Maybe broadcasts scalar value into vector type compatible with `op`.
170  auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
171  if (auto vec = dyn_cast<VectorType>(op.getType()))
172  return vector::BroadcastOp::create(rewriter, loc, vec, value);
173  return value;
174  };
175 
176  Value one;
177  Type opType = getElementTypeOrSelf(op.getType());
178  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
179  one = arith::ConstantOp::create(rewriter, loc,
180  rewriter.getFloatAttr(opType, 1.0));
181  else
182  one = arith::ConstantOp::create(rewriter, loc,
183  rewriter.getIntegerAttr(opType, 1));
184 
185  // Replace `[fi]powi(x, 0)` with `1`.
186  if (exponentValue == 0) {
187  rewriter.replaceOp(op, bcast(one));
188  return success();
189  }
190 
191  bool exponentIsNegative = false;
192  if (exponentValue < 0) {
193  exponentIsNegative = true;
194  exponentValue *= -1;
195  }
196 
197  // Bail out if `abs(exponent)` exceeds the threshold.
198  if (exponentValue > exponentThreshold)
199  return failure();
200 
201  Value result = base;
202  // Transform to naive sequence of multiplications:
203  // * For positive exponent case replace:
204  // `[fi]powi(x, positive_exponent)`
205  // with:
206  // x * x * x * ...
207  // * For negative exponent case replace:
208  // `[fi]powi(x, negative_exponent)`
209  // with:
210  // (1 / x) * (1 / x) * (1 / x) * ...
211  for (unsigned i = 1; i < exponentValue; ++i)
212  result = MulOpTy::create(rewriter, loc, result, base);
213 
214  // Inverse the base for negative exponent, i.e. for
215  // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
216  if (exponentIsNegative)
217  result = DivOpTy::create(rewriter, loc, bcast(one), result);
218 
219  rewriter.replaceOp(op, result);
220  return success();
221 }
222 
223 //----------------------------------------------------------------------------//
224 
227  patterns
228  .add<PowFStrengthReduction,
229  PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
230  PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
231  patterns.getContext());
232 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
An attribute that represents a reference to a dense float vector or tensor object.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319