MLIR 23.0.0git
SimplifyAffineWithBounds.cpp
Go to the documentation of this file.
1//===- SimplifyAffineWithBounds.cpp ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements simplification patterns for affine.delinearize_index /
10// affine.linearize_index pairs using value bounds analysis.
11//
12//===----------------------------------------------------------------------===//
13
15
23
24#define DEBUG_TYPE "affine-simplify-with-bounds"
25
26using namespace mlir;
27using namespace mlir::affine;
28
29/// Accumulate a single basis element into the running product expression.
30/// Static values become affine constants, and dynamic values become symbols.
31static void buildProductExpr(OpFoldResult basis, AffineExpr &productExpr,
32 SmallVectorImpl<Value> &operands,
33 MLIRContext *ctx) {
34 if (auto val = getConstantIntValue(basis)) {
35 productExpr = productExpr * getAffineConstantExpr(*val, ctx);
36 } else {
37 operands.push_back(cast<Value>(basis));
38 productExpr = productExpr * getAffineSymbolExpr(operands.size() - 1, ctx);
39 }
40}
41
42/// Try to find k consecutive elements from `lhs` (starting from tail offset)
43/// whose product equals the single next element from `rhs`.
44/// The product is accumulated incrementally to avoid redundant computation.
45/// Returns the number of matched elements k, or std::nullopt if no match.
46static std::optional<size_t> tryMatchProduct(ArrayRef<OpFoldResult> lhs,
47 size_t lhsTailConsumed,
49 size_t rhsTailConsumed,
50 MLIRContext *ctx) {
51 // Build a Variable for the single rhs element.
52 AffineExpr rhsExpr = getAffineConstantExpr(1, ctx);
53 SmallVector<Value> rhsOperands;
54 buildProductExpr(rhs[rhs.size() - rhsTailConsumed - 1], rhsExpr, rhsOperands,
55 ctx);
57 AffineMap::get(0, rhsOperands.size(), rhsExpr, ctx), rhsOperands);
58
59 // Incrementally accumulate lhs product and check for equality.
60 AffineExpr lhsExpr = getAffineConstantExpr(1, ctx);
61 SmallVector<Value> lhsOperands;
62 for (size_t k = 1; k + lhsTailConsumed <= lhs.size(); ++k) {
63 buildProductExpr(lhs[lhs.size() - lhsTailConsumed - k], lhsExpr,
64 lhsOperands, ctx);
65 AffineMap lhsMap = AffineMap::get(0, lhsOperands.size(), lhsExpr, ctx);
66 ValueBoundsConstraintSet::Variable lhsVar(lhsMap, lhsOperands);
67 FailureOr<bool> result = ValueBoundsConstraintSet::areEqual(lhsVar, rhsVar);
68 if (succeeded(result) && *result)
69 return k;
70 }
71 return std::nullopt;
72}
73
74namespace {
75
76/// Simplify delinearize(linearize) pairs from the tail by matching groups of
77/// dimensions whose basis products are equal via ValueBounds analysis.
78///
79/// For each step from the tail, tries:
80/// 1. Many-to-one: k linearize dims -> 1 delinearize dim
81/// 2. One-to-many: 1 linearize dim -> k delinearize dims
82///
83/// Matched trailing dimensions are peeled off. Unmatched prefix dimensions
84/// are left as residual linearize/delinearize operations.
85///
86/// Example (many-to-one, D*E == Z):
87/// %lin = affine.linearize_index disjoint [%a, %b, %c, %d, %e]
88/// by (A, B, C, D, E)
89/// %result:3 = affine.delinearize_index %lin into (X, Y, Z)
90/// ->
91/// %prefix_lin = affine.linearize_index disjoint [%a, %b, %c] by (A, B, C)
92/// %prefix:2 = affine.delinearize_index %prefix_lin into (X, Y)
93/// %tail = affine.linearize_index disjoint [%d, %e] by (D, E)
94/// %result = [%prefix#0, %prefix#1, %tail]
95struct SimplifyDelinearizeOfLinearizeDisjoint final
96 : OpRewritePattern<AffineDelinearizeIndexOp> {
97 using Base::Base;
98
99 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
100 PatternRewriter &rewriter) const override {
101 auto linearizeOp =
102 delinearizeOp.getLinearIndex().getDefiningOp<AffineLinearizeIndexOp>();
103 if (!linearizeOp)
104 return rewriter.notifyMatchFailure(delinearizeOp,
105 "index doesn't come from linearize");
106
107 if (!linearizeOp.getDisjoint())
108 return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
109
110 SmallVector<OpFoldResult> linBasis = linearizeOp.getMixedBasis();
111 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getMixedBasis();
112 ValueRange linInputs = linearizeOp.getMultiIndex();
113 MLIRContext *ctx = rewriter.getContext();
114
115 // Track how many elements consumed from each tail.
116 size_t linTailConsumed = 0;
117 size_t delinTailConsumed = 0;
118
119 // For each matched group (innermost first), record the number of
120 // linearize and delinearize dimensions it spans. Many-to-one groups
121 // have linCount > 1, one-to-many groups have delinCount > 1.
122 SmallVector<std::pair<size_t, size_t>> matchedGroups;
123
124 while (linTailConsumed < linBasis.size() &&
125 delinTailConsumed < delinBasis.size()) {
126 // Try many-to-one: k lin dims -> 1 delin dim.
127 if (std::optional<size_t> k = tryMatchProduct(
128 linBasis, linTailConsumed, delinBasis, delinTailConsumed, ctx)) {
129 matchedGroups.emplace_back(*k, 1);
130 linTailConsumed += *k;
131 delinTailConsumed += 1;
132 continue;
133 }
134 // Try one-to-many: 1 lin dim -> k delin dims.
135 if (std::optional<size_t> k = tryMatchProduct(
136 delinBasis, delinTailConsumed, linBasis, linTailConsumed, ctx)) {
137 matchedGroups.emplace_back(1, *k);
138 delinTailConsumed += *k;
139 linTailConsumed += 1;
140 continue;
141 }
142 break;
143 }
144
145 if (matchedGroups.empty())
146 return rewriter.notifyMatchFailure(delinearizeOp,
147 "no trailing dimensions matched");
148
149 SmallVector<Value> results;
150
151 // Build residual prefix ops for unmatched dimensions.
152 if (delinTailConsumed < delinBasis.size()) {
153 // Partial match: create residual linearize + delinearize for the
154 // unmatched prefix.
155 Value residualLinearize = AffineLinearizeIndexOp::create(
156 rewriter, linearizeOp.getLoc(), linInputs.drop_back(linTailConsumed),
157 ArrayRef(linBasis).drop_back(linTailConsumed),
158 linearizeOp.getDisjoint());
159 auto residualDelinearize = AffineDelinearizeIndexOp::create(
160 rewriter, delinearizeOp.getLoc(), residualLinearize,
161 ArrayRef(delinBasis).drop_back(delinTailConsumed),
162 delinearizeOp.hasOuterBound());
163 results.append(residualDelinearize.getResults().begin(),
164 residualDelinearize.getResults().end());
165 } else if (!delinearizeOp.hasOuterBound()) {
166 // All basis elements consumed, but the original delinearize has no outer
167 // bound which requires special handling.
168 ValueRange remainingInputs = linInputs.drop_back(linTailConsumed);
169 if (remainingInputs.empty()) {
170 // The outermost delinearize result is guaranteed to be zero.
171 results.push_back(arith::ConstantIndexOp::create(
172 rewriter, delinearizeOp.getLoc(), 0));
173 } else if (remainingInputs.size() == 1) {
174 // Pass through the single remaining input.
175 results.push_back(remainingInputs.front());
176 } else {
177 // Re-linearize the remaining inputs to produce the outermost result.
178 Value newLin = AffineLinearizeIndexOp::create(
179 rewriter, linearizeOp.getLoc(), remainingInputs,
180 ArrayRef(linBasis).drop_back(linTailConsumed),
181 linearizeOp.getDisjoint());
182 results.push_back(newLin);
183 }
184 }
185
186 // Build results for each matched group.
187 size_t linInputOffset = linInputs.size() - linTailConsumed;
188 size_t linBasisOffset = linBasis.size() - linTailConsumed;
189 size_t delinBasisOffset = delinBasis.size() - delinTailConsumed;
190 for (auto [linCount, delinCount] : llvm::reverse(matchedGroups)) {
191 if (linCount == 1 && delinCount == 1) {
192 // Exact 1:1 match: pass through directly.
193 results.push_back(linInputs[linInputOffset]);
194 } else if (linCount > 1) {
195 // Many-to-one: re-linearize the group's lin inputs.
196 Value newLin = AffineLinearizeIndexOp::create(
197 rewriter, linearizeOp.getLoc(),
198 linInputs.slice(linInputOffset, linCount),
199 ArrayRef(linBasis).slice(linBasisOffset, linCount),
200 /*disjoint=*/true);
201 results.push_back(newLin);
202 } else {
203 // One-to-many: delinearize the single lin input.
204 auto newDelin = AffineDelinearizeIndexOp::create(
205 rewriter, delinearizeOp.getLoc(), linInputs[linInputOffset],
206 ArrayRef(delinBasis).slice(delinBasisOffset, delinCount),
207 /*hasOuterBound=*/true);
208 results.append(newDelin.getResults().begin(),
209 newDelin.getResults().end());
210 }
211 linInputOffset += linCount;
212 linBasisOffset += linCount;
213 delinBasisOffset += delinCount;
214 }
215
216 rewriter.replaceOp(delinearizeOp, results);
217 return success();
218 }
219};
220
221} // namespace
222
224 RewritePatternSet &patterns) {
225 patterns.add<SimplifyDelinearizeOfLinearizeDisjoint>(patterns.getContext());
226}
227
228//===----------------------------------------------------------------------===//
229// Pass definition
230//===----------------------------------------------------------------------===//
231
232namespace mlir {
233namespace affine {
234#define GEN_PASS_DEF_SIMPLIFYAFFINEWITHBOUNDS
235#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
236} // namespace affine
237} // namespace mlir
238
239namespace {
240struct SimplifyAffineWithBoundsPass
241 : affine::impl::SimplifyAffineWithBoundsBase<SimplifyAffineWithBoundsPass> {
242 void runOnOperation() override {
243 RewritePatternSet patterns(&getContext());
244 // Add canonicalization patterns first so cheap exact-match cases are
245 // handled without invoking value bounds analysis.
246 AffineDelinearizeIndexOp::getCanonicalizationPatterns(patterns,
247 &getContext());
248 AffineLinearizeIndexOp::getCanonicalizationPatterns(patterns,
249 &getContext());
251 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
252 return signalPassFailure();
253 }
254};
255} // namespace
return success()
lhs
b getContext())
static void buildProductExpr(OpFoldResult basis, AffineExpr &productExpr, SmallVectorImpl< Value > &operands, MLIRContext *ctx)
Accumulate a single basis element into the running product expression.
static std::optional< size_t > tryMatchProduct(ArrayRef< OpFoldResult > lhs, size_t lhsTailConsumed, ArrayRef< OpFoldResult > rhs, size_t rhsTailConsumed, MLIRContext *ctx)
Try to find k consecutive elements from lhs (starting from tail offset) whose product equals the sing...
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
Definition Builders.h:56
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
void populateSimplifyAffineWithBoundsPatterns(RewritePatternSet &patterns)
Populate patterns that simplify affine.delinearize_index / affine.linearize_index pairs using value b...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...