MLIR 23.0.0git
SimplifyAffineWithBounds.cpp
Go to the documentation of this file.
1//===- SimplifyAffineWithBounds.cpp ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements simplification patterns for affine.delinearize_index /
10// affine.linearize_index pairs using value bounds analysis.
11//
12//===----------------------------------------------------------------------===//
13
15
23
24#define DEBUG_TYPE "affine-simplify-with-bounds"
25
26using namespace mlir;
27using namespace mlir::affine;
28
29/// Accumulate a single basis element into the running product expression.
30/// Static values become affine constants, and dynamic values become symbols.
31static void buildProductExpr(OpFoldResult basis, AffineExpr &productExpr,
32 SmallVectorImpl<Value> &operands,
33 MLIRContext *ctx) {
34 if (auto val = getConstantIntValue(basis)) {
35 productExpr = productExpr * getAffineConstantExpr(*val, ctx);
36 } else {
37 operands.push_back(cast<Value>(basis));
38 productExpr = productExpr * getAffineSymbolExpr(operands.size() - 1, ctx);
39 }
40}
41
42/// Try to find k consecutive elements from `lhs` (starting from tail offset)
43/// whose product equals the single next element from `rhs`.
44/// The product is accumulated incrementally to avoid redundant computation.
45/// Returns the number of matched elements k, or std::nullopt if no match.
46static std::optional<size_t> tryMatchProduct(ArrayRef<OpFoldResult> lhs,
47 size_t lhsTailConsumed,
49 size_t rhsTailConsumed,
50 MLIRContext *ctx) {
51 // Build a Variable for the single rhs element.
52 AffineExpr rhsExpr = getAffineConstantExpr(1, ctx);
53 SmallVector<Value> rhsOperands;
54 buildProductExpr(rhs[rhs.size() - rhsTailConsumed - 1], rhsExpr, rhsOperands,
55 ctx);
57 AffineMap::get(0, rhsOperands.size(), rhsExpr, ctx), rhsOperands);
58
59 // Incrementally accumulate lhs product and check for equality.
60 AffineExpr lhsExpr = getAffineConstantExpr(1, ctx);
61 SmallVector<Value> lhsOperands;
62 for (size_t k = 1; k + lhsTailConsumed <= lhs.size(); ++k) {
63 buildProductExpr(lhs[lhs.size() - lhsTailConsumed - k], lhsExpr,
64 lhsOperands, ctx);
65 AffineMap lhsMap = AffineMap::get(0, lhsOperands.size(), lhsExpr, ctx);
66 ValueBoundsConstraintSet::Variable lhsVar(lhsMap, lhsOperands);
67 FailureOr<bool> result = ValueBoundsConstraintSet::areEqual(lhsVar, rhsVar);
68 if (succeeded(result) && *result)
69 return k;
70 }
71 return std::nullopt;
72}
73
74namespace {
75
76/// Simplify delinearize(linearize) pairs from the tail by matching groups of
77/// dimensions whose basis products are equal via ValueBounds analysis.
78///
79/// For each step from the tail, tries:
80/// 1. Many-to-one: k linearize dims -> 1 delinearize dim
81/// 2. One-to-many: 1 linearize dim -> k delinearize dims
82///
83/// Matched trailing dimensions are peeled off. Unmatched prefix dimensions
84/// are left as residual linearize/delinearize operations.
85///
86/// Example (many-to-one, D*E == Z):
87/// %lin = affine.linearize_index disjoint [%a, %b, %c, %d, %e]
88/// by (A, B, C, D, E)
89/// %result:3 = affine.delinearize_index %lin into (X, Y, Z)
90/// ->
91/// %prefix_lin = affine.linearize_index disjoint [%a, %b, %c] by (A, B, C)
92/// %prefix:2 = affine.delinearize_index %prefix_lin into (X, Y)
93/// %tail = affine.linearize_index disjoint [%d, %e] by (D, E)
94/// %result = [%prefix#0, %prefix#1, %tail]
95struct SimplifyDelinearizeOfLinearizeDisjoint final
96 : OpRewritePattern<AffineDelinearizeIndexOp> {
97 using Base::Base;
98
99 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
100 PatternRewriter &rewriter) const override {
101 auto linearizeOp =
102 delinearizeOp.getLinearIndex().getDefiningOp<AffineLinearizeIndexOp>();
103 if (!linearizeOp)
104 return rewriter.notifyMatchFailure(delinearizeOp,
105 "index doesn't come from linearize");
106
107 if (!linearizeOp.getDisjoint())
108 return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
109
110 SmallVector<OpFoldResult> linBasis = linearizeOp.getMixedBasis();
111 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getMixedBasis();
112 ValueRange linInputs = linearizeOp.getMultiIndex();
113 MLIRContext *ctx = rewriter.getContext();
114
115 // Track how many elements consumed from each tail.
116 size_t linTailConsumed = 0;
117 size_t delinTailConsumed = 0;
118
119 // For each matched group (innermost first), record the number of
120 // linearize and delinearize dimensions it spans. Many-to-one groups
121 // have linCount > 1, one-to-many groups have delinCount > 1.
122 SmallVector<std::pair<size_t, size_t>> matchedGroups;
123
124 while (linTailConsumed < linBasis.size() &&
125 delinTailConsumed < delinBasis.size()) {
126 // Try many-to-one: k lin dims -> 1 delin dim.
127 if (std::optional<size_t> k = tryMatchProduct(
128 linBasis, linTailConsumed, delinBasis, delinTailConsumed, ctx)) {
129 matchedGroups.emplace_back(*k, 1);
130 linTailConsumed += *k;
131 delinTailConsumed += 1;
132 continue;
133 }
134 // Try one-to-many: 1 lin dim -> k delin dims.
135 if (std::optional<size_t> k = tryMatchProduct(
136 delinBasis, delinTailConsumed, linBasis, linTailConsumed, ctx)) {
137 matchedGroups.emplace_back(1, *k);
138 delinTailConsumed += *k;
139 linTailConsumed += 1;
140 continue;
141 }
142 break;
143 }
144
145 if (matchedGroups.empty())
146 return rewriter.notifyMatchFailure(delinearizeOp,
147 "no trailing dimensions matched");
148
149 SmallVector<Value> results;
150
151 // Build residual prefix ops for unmatched dimensions.
152 if (delinTailConsumed < delinBasis.size()) {
153 // Partial match: create residual linearize + delinearize for the
154 // unmatched prefix.
155 Value residualLinearize = AffineLinearizeIndexOp::create(
156 rewriter, linearizeOp.getLoc(), linInputs.drop_back(linTailConsumed),
157 ArrayRef(linBasis).drop_back(linTailConsumed),
158 linearizeOp.getDisjoint());
159 auto residualDelinearize = AffineDelinearizeIndexOp::create(
160 rewriter, delinearizeOp.getLoc(), residualLinearize,
161 ArrayRef(delinBasis).drop_back(delinTailConsumed),
162 delinearizeOp.hasOuterBound());
163 results.append(residualDelinearize.getResults().begin(),
164 residualDelinearize.getResults().end());
165 } else if (!delinearizeOp.hasOuterBound()) {
166 // All basis elements consumed, but the original delinearize has no outer
167 // bound which requires special handling.
168 ValueRange remainingInputs = linInputs.drop_back(linTailConsumed);
169 if (remainingInputs.empty()) {
170 // The outermost delinearize result is guaranteed to be zero.
171 results.push_back(arith::ConstantIndexOp::create(
172 rewriter, delinearizeOp.getLoc(), 0));
173 } else if (remainingInputs.size() == 1) {
174 // Pass through the single remaining input.
175 results.push_back(remainingInputs.front());
176 } else {
177 // Re-linearize the remaining inputs to produce the outermost result.
178 Value newLin = AffineLinearizeIndexOp::create(
179 rewriter, linearizeOp.getLoc(), remainingInputs,
180 ArrayRef(linBasis).drop_back(linTailConsumed),
181 linearizeOp.getDisjoint());
182 results.push_back(newLin);
183 }
184 }
185
186 // Build results for each matched group.
187 size_t linInputOffset = linInputs.size() - linTailConsumed;
188 size_t linBasisOffset = linBasis.size() - linTailConsumed;
189 size_t delinBasisOffset = delinBasis.size() - delinTailConsumed;
190 for (auto [linCount, delinCount] : llvm::reverse(matchedGroups)) {
191 if (linCount == 1 && delinCount == 1) {
192 // Exact 1:1 match: pass through directly.
193 results.push_back(linInputs[linInputOffset]);
194 } else if (linCount > 1) {
195 // Many-to-one: re-linearize the group's lin inputs.
196 Value newLin = AffineLinearizeIndexOp::create(
197 rewriter, linearizeOp.getLoc(),
198 linInputs.slice(linInputOffset, linCount),
199 ArrayRef(linBasis).slice(linBasisOffset, linCount),
200 /*disjoint=*/true);
201 results.push_back(newLin);
202 } else {
203 // One-to-many: delinearize the single lin input.
204 auto newDelin = AffineDelinearizeIndexOp::create(
205 rewriter, delinearizeOp.getLoc(), linInputs[linInputOffset],
206 ArrayRef(delinBasis).slice(delinBasisOffset, delinCount),
207 /*hasOuterBound=*/true);
208 results.append(newDelin.getResults().begin(),
209 newDelin.getResults().end());
210 }
211 linInputOffset += linCount;
212 linBasisOffset += linCount;
213 delinBasisOffset += delinCount;
214 }
215
216 rewriter.replaceOp(delinearizeOp, results);
217 return success();
218 }
219};
220
221/// Simplifies the affine map results by eliminating redundant expressions.
222///
223/// This function performs a pairwise comparison of all expressions in the map
224/// using the analysis from `ValueBoundsConstraintSet`. If an expression `a` is
225/// statically proven to be strictly bounded or covered by another expression
226/// `b` (based on the given comparison operator `cmp`), `a` is considered
227/// redundant and is safely pruned from the results.
229simplifyRedundantMapResults(AffineMap map, ValueRange operands,
231 llvm::BitVector preservedExprs(map.getNumResults(), true);
232 for (size_t i = 0, e = map.getNumResults(); i < e; ++i) {
233 AffineMap mapA = map.getSubMap(i);
234 ValueBoundsConstraintSet::Variable varA(mapA, operands);
235
236 for (size_t j = 0; j < e; ++j) {
237 if (i == j || !preservedExprs[j])
238 continue;
239
240 AffineMap mapB = map.getSubMap(j);
241 ValueBoundsConstraintSet::Variable varB(mapB, operands);
242
243 if (ValueBoundsConstraintSet::compare(varB, cmp, varA)) {
244 preservedExprs[i] = false;
245 break;
246 }
247 }
248 }
249
250 SmallVector<AffineExpr> mapResults;
251 for (size_t i = 0, e = map.getNumResults(); i < e; ++i)
252 if (preservedExprs[i])
253 mapResults.push_back(map.getResult(i));
254 return mapResults;
255}
256
257/// A pattern that simplifies multi-result lower and upper bounds of
258/// `affine.for` loops by pruning redundant expressions leveraging
259/// `ValueBoundsConstraintSet`.
260struct SimplifyAffineLoopBoundMap final : OpRewritePattern<AffineForOp> {
261 using Base::Base;
262 LogicalResult matchAndRewrite(AffineForOp forOp,
263 PatternRewriter &rewriter) const override {
264 AffineMap lowerBoundMap = forOp.getLowerBoundMap();
265 auto lowerBoundOperands = forOp.getLowerBoundOperands();
266 AffineMap upperBoundMap = forOp.getUpperBoundMap();
267 auto upperBoundOperands = forOp.getUpperBoundOperands();
268 if (lowerBoundMap.getNumResults() < 2 &&
269 forOp.getUpperBoundMap().getNumResults() < 2)
270 return failure();
271
272 SmallVector<AffineExpr> lowerMapExprs = simplifyRedundantMapResults(
273 lowerBoundMap, lowerBoundOperands, ValueBoundsConstraintSet::GT);
274 SmallVector<AffineExpr> upperMapExprs = simplifyRedundantMapResults(
275 upperBoundMap, upperBoundOperands, ValueBoundsConstraintSet::LT);
276
277 bool lowerBoundUpdate =
278 lowerMapExprs.size() < lowerBoundMap.getNumResults();
279 bool upperBoundUpdate =
280 upperMapExprs.size() < upperBoundMap.getNumResults();
281 if (!(lowerBoundUpdate || upperBoundUpdate))
282 return failure();
283
284 MLIRContext *context = forOp->getContext();
285 if (lowerBoundUpdate) {
286 rewriter.modifyOpInPlace(forOp, [&]() {
287 forOp.setLowerBound(forOp.getLowerBoundOperands(),
288 AffineMap::get(lowerBoundMap.getNumDims(),
289 lowerBoundMap.getNumSymbols(),
290 lowerMapExprs, context));
291 });
292 }
293 if (upperBoundUpdate) {
294 rewriter.modifyOpInPlace(forOp, [&]() {
295 forOp.setUpperBound(forOp.getUpperBoundOperands(),
296 AffineMap::get(upperBoundMap.getNumDims(),
297 upperBoundMap.getNumSymbols(),
298 upperMapExprs, context));
299 });
300 }
301 return success();
302 }
303};
304} // namespace
305
307 RewritePatternSet &patterns) {
308 patterns
309 .add<SimplifyDelinearizeOfLinearizeDisjoint, SimplifyAffineLoopBoundMap>(
310 patterns.getContext());
311}
312
313//===----------------------------------------------------------------------===//
314// Pass definition
315//===----------------------------------------------------------------------===//
316
317namespace mlir {
318namespace affine {
319#define GEN_PASS_DEF_SIMPLIFYAFFINEWITHBOUNDS
320#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
321} // namespace affine
322} // namespace mlir
323
324namespace {
325struct SimplifyAffineWithBoundsPass
326 : affine::impl::SimplifyAffineWithBoundsBase<SimplifyAffineWithBoundsPass> {
327 void runOnOperation() override {
328 RewritePatternSet patterns(&getContext());
329 // Add canonicalization patterns first so cheap exact-match cases are
330 // handled without invoking value bounds analysis.
331 AffineDelinearizeIndexOp::getCanonicalizationPatterns(patterns,
332 &getContext());
333 AffineLinearizeIndexOp::getCanonicalizationPatterns(patterns,
334 &getContext());
336 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
337 return signalPassFailure();
338 }
339};
340} // namespace
return success()
lhs
b getContext())
static void buildProductExpr(OpFoldResult basis, AffineExpr &productExpr, SmallVectorImpl< Value > &operands, MLIRContext *ctx)
Accumulate a single basis element into the running product expression.
static std::optional< size_t > tryMatchProduct(ArrayRef< OpFoldResult > lhs, size_t lhsTailConsumed, ArrayRef< OpFoldResult > rhs, size_t rhsTailConsumed, MLIRContext *ctx)
Try to find k consecutive elements from lhs (starting from tail offset) whose product equals the sing...
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
MLIRContext * getContext() const
Definition Builders.h:56
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
ComparisonOperator
Comparison operator for ValueBoundsConstraintSet::compare.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
void populateSimplifyAffineWithBoundsPatterns(RewritePatternSet &patterns)
Populate patterns that simplify affine.delinearize_index / affine.linearize_index pairs using value b...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.