24#define DEBUG_TYPE "affine-simplify-with-bounds"
37 operands.push_back(cast<Value>(basis));
47 size_t lhsTailConsumed,
49 size_t rhsTailConsumed,
62 for (
size_t k = 1; k + lhsTailConsumed <=
lhs.size(); ++k) {
95struct SimplifyDelinearizeOfLinearizeDisjoint final
99 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
100 PatternRewriter &rewriter)
const override {
102 delinearizeOp.getLinearIndex().getDefiningOp<AffineLinearizeIndexOp>();
105 "index doesn't come from linearize");
107 if (!linearizeOp.getDisjoint())
110 SmallVector<OpFoldResult> linBasis = linearizeOp.getMixedBasis();
111 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getMixedBasis();
112 ValueRange linInputs = linearizeOp.getMultiIndex();
116 size_t linTailConsumed = 0;
117 size_t delinTailConsumed = 0;
122 SmallVector<std::pair<size_t, size_t>> matchedGroups;
124 while (linTailConsumed < linBasis.size() &&
125 delinTailConsumed < delinBasis.size()) {
128 linBasis, linTailConsumed, delinBasis, delinTailConsumed, ctx)) {
129 matchedGroups.emplace_back(*k, 1);
130 linTailConsumed += *k;
131 delinTailConsumed += 1;
136 delinBasis, delinTailConsumed, linBasis, linTailConsumed, ctx)) {
137 matchedGroups.emplace_back(1, *k);
138 delinTailConsumed += *k;
139 linTailConsumed += 1;
145 if (matchedGroups.empty())
147 "no trailing dimensions matched");
149 SmallVector<Value> results;
152 if (delinTailConsumed < delinBasis.size()) {
155 Value residualLinearize = AffineLinearizeIndexOp::create(
156 rewriter, linearizeOp.getLoc(), linInputs.drop_back(linTailConsumed),
157 ArrayRef(linBasis).drop_back(linTailConsumed),
158 linearizeOp.getDisjoint());
159 auto residualDelinearize = AffineDelinearizeIndexOp::create(
160 rewriter, delinearizeOp.getLoc(), residualLinearize,
161 ArrayRef(delinBasis).drop_back(delinTailConsumed),
162 delinearizeOp.hasOuterBound());
163 results.append(residualDelinearize.getResults().begin(),
164 residualDelinearize.getResults().end());
165 }
else if (!delinearizeOp.hasOuterBound()) {
168 ValueRange remainingInputs = linInputs.drop_back(linTailConsumed);
169 if (remainingInputs.empty()) {
172 rewriter, delinearizeOp.getLoc(), 0));
173 }
else if (remainingInputs.size() == 1) {
175 results.push_back(remainingInputs.front());
178 Value newLin = AffineLinearizeIndexOp::create(
179 rewriter, linearizeOp.getLoc(), remainingInputs,
180 ArrayRef(linBasis).drop_back(linTailConsumed),
181 linearizeOp.getDisjoint());
182 results.push_back(newLin);
187 size_t linInputOffset = linInputs.size() - linTailConsumed;
188 size_t linBasisOffset = linBasis.size() - linTailConsumed;
189 size_t delinBasisOffset = delinBasis.size() - delinTailConsumed;
190 for (
auto [linCount, delinCount] : llvm::reverse(matchedGroups)) {
191 if (linCount == 1 && delinCount == 1) {
193 results.push_back(linInputs[linInputOffset]);
194 }
else if (linCount > 1) {
196 Value newLin = AffineLinearizeIndexOp::create(
197 rewriter, linearizeOp.getLoc(),
198 linInputs.slice(linInputOffset, linCount),
199 ArrayRef(linBasis).slice(linBasisOffset, linCount),
201 results.push_back(newLin);
204 auto newDelin = AffineDelinearizeIndexOp::create(
205 rewriter, delinearizeOp.getLoc(), linInputs[linInputOffset],
206 ArrayRef(delinBasis).slice(delinBasisOffset, delinCount),
208 results.append(newDelin.getResults().begin(),
209 newDelin.getResults().end());
211 linInputOffset += linCount;
212 linBasisOffset += linCount;
213 delinBasisOffset += delinCount;
216 rewriter.
replaceOp(delinearizeOp, results);
225 patterns.
add<SimplifyDelinearizeOfLinearizeDisjoint>(patterns.
getContext());
234#define GEN_PASS_DEF_SIMPLIFYAFFINEWITHBOUNDS
235#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
240struct SimplifyAffineWithBoundsPass
241 : affine::impl::SimplifyAffineWithBoundsBase<SimplifyAffineWithBoundsPass> {
242 void runOnOperation()
override {
246 AffineDelinearizeIndexOp::getCanonicalizationPatterns(patterns,
248 AffineLinearizeIndexOp::getCanonicalizationPatterns(patterns,
252 return signalPassFailure();
static void buildProductExpr(OpFoldResult basis, AffineExpr &productExpr, SmallVectorImpl< Value > &operands, MLIRContext *ctx)
Accumulate a single basis element into the running product expression.
static std::optional< size_t > tryMatchProduct(ArrayRef< OpFoldResult > lhs, size_t lhsTailConsumed, ArrayRef< OpFoldResult > rhs, size_t rhsTailConsumed, MLIRContext *ctx)
Try to find k consecutive elements from lhs (starting from tail offset) whose product equals the sing...
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateSimplifyAffineWithBoundsPatterns(RewritePatternSet &patterns)
Populate patterns that simplify affine.delinearize_index / affine.linearize_index pairs using value b...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...