MLIR 22.0.0git
AlgebraicSimplification.cpp
Go to the documentation of this file.
//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites based on the basic rules of algebra
// (commutativity, associativity, etc.) and strength reductions for math
// operations, such as replacing power operations that have small constant
// exponents with cheaper multiplications, divisions and square roots.
//
//===----------------------------------------------------------------------===//
14
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/Matchers.h"
23#include <climits>
24
25using namespace mlir;
26
//----------------------------------------------------------------------------//
// PowFOp strength reduction.
//----------------------------------------------------------------------------//
30
31namespace {
32struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
33public:
35
36 LogicalResult matchAndRewrite(math::PowFOp op,
37 PatternRewriter &rewriter) const final;
38};
39} // namespace
40
41LogicalResult
42PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
43 PatternRewriter &rewriter) const {
44 Location loc = op.getLoc();
45 Value x = op.getLhs();
46 arith::FastMathFlags fmf = op.getFastmathAttr().getValue();
47
48 FloatAttr scalarExponent;
49 DenseFPElementsAttr vectorExponent;
50
51 bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
52 bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
53
54 // Returns true if exponent is a constant equal to `value`.
55 auto isExponentValue = [&](double value) -> bool {
56 if (isScalar)
57 return scalarExponent.getValue().isExactlyValue(value);
58
59 if (isVector && vectorExponent.isSplat())
60 return vectorExponent.getSplatValue<FloatAttr>()
61 .getValue()
62 .isExactlyValue(value);
63
64 return false;
65 };
66
67 // Maybe broadcasts scalar value into vector type compatible with `op`.
68 auto bcast = [&](Value value) -> Value {
69 if (auto vec = dyn_cast<VectorType>(op.getType()))
70 return vector::BroadcastOp::create(rewriter, loc, vec, value);
71 return value;
72 };
73
74 // Replace `pow(x, 1.0)` with `x`.
75 if (isExponentValue(1.0)) {
76 rewriter.replaceOp(op, x);
77 return success();
78 }
79
80 // Replace `pow(x, 2.0)` with `x * x`.
81 if (isExponentValue(2.0)) {
82 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, x, fmf);
83 return success();
84 }
85
86 // Replace `pow(x, 3.0)` with `x * x * x`.
87 if (isExponentValue(3.0)) {
88 Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
89 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, square, fmf);
90 return success();
91 }
92
93 // Replace `pow(x, -1.0)` with `1.0 / x`.
94 if (isExponentValue(-1.0)) {
95 Value one = arith::ConstantOp::create(
96 rewriter, loc,
97 rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
98 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, bcast(one), x, fmf);
99 return success();
100 }
101
102 // Replace `pow(x, 0.5)` with `sqrt(x)`.
103 if (isExponentValue(0.5)) {
104 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x, fmf);
105 return success();
106 }
107
108 // Replace `pow(x, -0.5)` with `rsqrt(x)`.
109 if (isExponentValue(-0.5)) {
110 rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x, fmf);
111 return success();
112 }
113
114 // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
115 if (isExponentValue(0.75)) {
116 Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
117 Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
118 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, powHalf, powQuarter, fmf);
119 return success();
120 }
121
122 return failure();
123}
124
//----------------------------------------------------------------------------//
// FPowIOp/IPowIOp/complex::PowiOp strength reduction.
//----------------------------------------------------------------------------//
128
129namespace {
130template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
131struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
132
133 unsigned exponentThreshold;
134
135public:
136 PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
137 PatternBenefit benefit = 1,
138 ArrayRef<StringRef> generatedNames = {})
139 : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
140 exponentThreshold(exponentThreshold) {}
141
142 LogicalResult matchAndRewrite(PowIOpTy op,
143 PatternRewriter &rewriter) const final;
144};
145} // namespace
146
147template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
148LogicalResult
149PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
150 PowIOpTy op, PatternRewriter &rewriter) const {
151 Location loc = op.getLoc();
152 Value base = op.getLhs();
153
154 IntegerAttr scalarExponent;
155 DenseIntElementsAttr vectorExponent;
156
157 bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
158 bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
159
160 // Simplify cases with known exponent value.
161 int64_t exponentValue = 0;
162 if (isScalar)
163 exponentValue = scalarExponent.getInt();
164 else if (isVector && vectorExponent.isSplat())
165 exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
166 else
167 return failure();
168
169 // Maybe broadcasts scalar value into vector type compatible with `op`.
170 auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
171 if (auto vec = dyn_cast<VectorType>(op.getType()))
172 return vector::BroadcastOp::create(rewriter, loc, vec, value);
173 return value;
174 };
175
176 Value one;
177 Type opType = getElementTypeOrSelf(op.getType());
178 if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
179 one = arith::ConstantOp::create(rewriter, loc,
180 rewriter.getFloatAttr(opType, 1.0));
181 } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
182 auto complexTy = cast<ComplexType>(opType);
183 Type elementType = complexTy.getElementType();
184 auto realPart = rewriter.getFloatAttr(elementType, 1.0);
185 auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
186 one = complex::ConstantOp::create(
187 rewriter, loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
188 } else {
189 one = arith::ConstantOp::create(rewriter, loc,
190 rewriter.getIntegerAttr(opType, 1));
191 }
192
193 // Replace `[fi]powi(x, 0)` with `1`.
194 if (exponentValue == 0) {
195 rewriter.replaceOp(op, bcast(one));
196 return success();
197 }
198
199 bool exponentIsNegative = false;
200 if (exponentValue < 0) {
201 exponentIsNegative = true;
202 exponentValue *= -1;
203 }
204
205 // Bail out if `abs(exponent)` exceeds the threshold.
206 if (exponentValue > exponentThreshold)
207 return failure();
208
209 Value result = base;
210 // Transform to naive sequence of multiplications:
211 // * For positive exponent case replace:
212 // `[fi]powi(x, positive_exponent)`
213 // with:
214 // x * x * x * ...
215 // * For negative exponent case replace:
216 // `[fi]powi(x, negative_exponent)`
217 // with:
218 // (1 / x) * (1 / x) * (1 / x) * ...
219 auto buildMul = [&](Value lhs, Value rhs) {
220 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
221 return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
222 op.getFastmathAttr());
223 else
224 return MulOpTy::create(rewriter, loc, lhs, rhs);
225 };
226 for (unsigned i = 1; i < exponentValue; ++i)
227 result = buildMul(result, base);
228
229 // Inverse the base for negative exponent, i.e. for
230 // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
231 if (exponentIsNegative) {
232 if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
233 result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result,
234 op.getFastmathAttr());
235 else
236 result = DivOpTy::create(rewriter, loc, bcast(one), result);
237 }
238
239 rewriter.replaceOp(op, result);
240 return success();
241}
242
243//----------------------------------------------------------------------------//
244
247 patterns.add<
248 PowFStrengthReduction,
249 PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
250 PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
251 PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
252 patterns.getContext(), /*exponentThreshold=*/8);
253}
return success()
lhs
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...