MLIR  22.0.0git
AlgebraicSimplification.cpp
Go to the documentation of this file.
1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites based on the basic rules of algebra
10 // (Commutativity, associativity, etc...) and strength reductions for math
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include <climits>
24 
25 using namespace mlir;
26 
27 //----------------------------------------------------------------------------//
28 // PowFOp strength reduction.
29 //----------------------------------------------------------------------------//
30 
31 namespace {
32 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
33 public:
35 
36  LogicalResult matchAndRewrite(math::PowFOp op,
37  PatternRewriter &rewriter) const final;
38 };
39 } // namespace
40 
41 LogicalResult
42 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
43  PatternRewriter &rewriter) const {
44  Location loc = op.getLoc();
45  Value x = op.getLhs();
46 
47  FloatAttr scalarExponent;
48  DenseFPElementsAttr vectorExponent;
49 
50  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
51  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
52 
53  // Returns true if exponent is a constant equal to `value`.
54  auto isExponentValue = [&](double value) -> bool {
55  if (isScalar)
56  return scalarExponent.getValue().isExactlyValue(value);
57 
58  if (isVector && vectorExponent.isSplat())
59  return vectorExponent.getSplatValue<FloatAttr>()
60  .getValue()
61  .isExactlyValue(value);
62 
63  return false;
64  };
65 
66  // Maybe broadcasts scalar value into vector type compatible with `op`.
67  auto bcast = [&](Value value) -> Value {
68  if (auto vec = dyn_cast<VectorType>(op.getType()))
69  return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
70  return value;
71  };
72 
73  // Replace `pow(x, 1.0)` with `x`.
74  if (isExponentValue(1.0)) {
75  rewriter.replaceOp(op, x);
76  return success();
77  }
78 
79  // Replace `pow(x, 2.0)` with `x * x`.
80  if (isExponentValue(2.0)) {
81  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
82  return success();
83  }
84 
85  // Replace `pow(x, 3.0)` with `x * x * x`.
86  if (isExponentValue(3.0)) {
87  Value square =
88  arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
89  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
90  return success();
91  }
92 
93  // Replace `pow(x, -1.0)` with `1.0 / x`.
94  if (isExponentValue(-1.0)) {
95  Value one = arith::ConstantOp::create(
96  rewriter, loc,
97  rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
98  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
99  return success();
100  }
101 
102  // Replace `pow(x, 0.5)` with `sqrt(x)`.
103  if (isExponentValue(0.5)) {
104  rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
105  return success();
106  }
107 
108  // Replace `pow(x, -0.5)` with `rsqrt(x)`.
109  if (isExponentValue(-0.5)) {
110  rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
111  return success();
112  }
113 
114  // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
115  if (isExponentValue(0.75)) {
116  Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
117  Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
118  rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
119  ValueRange{powHalf, powQuarter});
120  return success();
121  }
122 
123  return failure();
124 }
125 
126 //----------------------------------------------------------------------------//
127 // FPowIOp/IPowIOp strength reduction.
128 //----------------------------------------------------------------------------//
129 
130 namespace {
131 template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
132 struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
133 
134  unsigned exponentThreshold;
135 
136 public:
137  PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
138  PatternBenefit benefit = 1,
139  ArrayRef<StringRef> generatedNames = {})
140  : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
141  exponentThreshold(exponentThreshold) {}
142 
143  LogicalResult matchAndRewrite(PowIOpTy op,
144  PatternRewriter &rewriter) const final;
145 };
146 } // namespace
147 
148 template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
149 LogicalResult
150 PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
151  PowIOpTy op, PatternRewriter &rewriter) const {
152  Location loc = op.getLoc();
153  Value base = op.getLhs();
154 
155  IntegerAttr scalarExponent;
156  DenseIntElementsAttr vectorExponent;
157 
158  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
159  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
160 
161  // Simplify cases with known exponent value.
162  int64_t exponentValue = 0;
163  if (isScalar)
164  exponentValue = scalarExponent.getInt();
165  else if (isVector && vectorExponent.isSplat())
166  exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
167  else
168  return failure();
169 
170  // Maybe broadcasts scalar value into vector type compatible with `op`.
171  auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
172  if (auto vec = dyn_cast<VectorType>(op.getType()))
173  return vector::BroadcastOp::create(rewriter, loc, vec, value);
174  return value;
175  };
176 
177  Value one;
178  Type opType = getElementTypeOrSelf(op.getType());
179  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
180  one = arith::ConstantOp::create(rewriter, loc,
181  rewriter.getFloatAttr(opType, 1.0));
182  } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
183  auto complexTy = cast<ComplexType>(opType);
184  Type elementType = complexTy.getElementType();
185  auto realPart = rewriter.getFloatAttr(elementType, 1.0);
186  auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
187  one = complex::ConstantOp::create(
188  rewriter, loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
189  } else {
190  one = arith::ConstantOp::create(rewriter, loc,
191  rewriter.getIntegerAttr(opType, 1));
192  }
193 
194  // Replace `[fi]powi(x, 0)` with `1`.
195  if (exponentValue == 0) {
196  rewriter.replaceOp(op, bcast(one));
197  return success();
198  }
199 
200  bool exponentIsNegative = false;
201  if (exponentValue < 0) {
202  exponentIsNegative = true;
203  exponentValue *= -1;
204  }
205 
206  // Bail out if `abs(exponent)` exceeds the threshold.
207  if (exponentValue > exponentThreshold)
208  return failure();
209 
210  Value result = base;
211  // Transform to naive sequence of multiplications:
212  // * For positive exponent case replace:
213  // `[fi]powi(x, positive_exponent)`
214  // with:
215  // x * x * x * ...
216  // * For negative exponent case replace:
217  // `[fi]powi(x, negative_exponent)`
218  // with:
219  // (1 / x) * (1 / x) * (1 / x) * ...
220  auto buildMul = [&](Value lhs, Value rhs) {
221  if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
222  return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
223  op.getFastmathAttr());
224  else
225  return MulOpTy::create(rewriter, loc, lhs, rhs);
226  };
227  for (unsigned i = 1; i < exponentValue; ++i)
228  result = buildMul(result, base);
229 
230  // Inverse the base for negative exponent, i.e. for
231  // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
232  if (exponentIsNegative) {
233  if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
234  result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result,
235  op.getFastmathAttr());
236  else
237  result = DivOpTy::create(rewriter, loc, bcast(one), result);
238  }
239 
240  rewriter.replaceOp(op, result);
241  return success();
242 }
243 
244 //----------------------------------------------------------------------------//
245 
248  patterns.add<
249  PowFStrengthReduction,
250  PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
251  PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
252  PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
253  patterns.getContext(), /*exponentThreshold=*/8);
254 }
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:253
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
An attribute that represents a reference to a dense float vector or tensor object.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:322