MLIR 22.0.0git
SimplifyAffineMinMax.cpp
Go to the documentation of this file.
1//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a transform to simplify min/max affine operations.
10//
11//===----------------------------------------------------------------------===//
12
14
21#include "llvm/ADT/IntEqClasses.h"
22#include "llvm/Support/DebugLog.h"
23#include "llvm/Support/InterleavedRange.h"
24
25#define DEBUG_TYPE "affine-min-max"
26
27using namespace mlir;
28using namespace mlir::affine;
29
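/// Simplifies an affine min/max operation by proving that some of its results
/// bound others (from below for min, from above for max), keeping only one
/// bounding result per group of comparable results. Returns true if the op
/// was modified.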
32template <typename AffineOp>
33static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
35 using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
36
37 AffineMap affineMap = affineOp.getMap();
38 ValueRange operands = affineOp.getOperands();
39 static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
40
41 LDBG() << "analyzing value: `" << affineOp;
42
43 // Create a `Variable` list with values corresponding to each of the results
44 // in the affine affineMap.
45 SmallVector<Variable> variables = llvm::map_to_vector(
46 llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
47 [&](unsigned i) {
48 return Variable(affineMap.getSliceMap(i, 1), operands);
49 });
50 LDBG() << "- constructed variables are: "
51 << llvm::interleaved_array(llvm::map_range(
52 variables, [](const Variable &v) { return v.getMap(); }));
53
54 // Get the comparison operation.
55 ComparisonOperator cmpOp =
56 isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
57
58 // Find disjoint sets bounded by a common value.
59 llvm::IntEqClasses boundedClasses(variables.size());
61 for (auto &&[i, v] : llvm::enumerate(variables)) {
62 unsigned eqClass = boundedClasses.findLeader(i);
63
64 // If the class already has a bound continue.
65 if (bounds.contains(eqClass))
66 continue;
67
68 // Initialize the bound.
69 Variable *bound = &v;
70
71 LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
72 << "`\n";
73
74 // Check against the other variables.
75 for (size_t j = i + 1; j < variables.size(); ++j) {
76 unsigned jEqClass = boundedClasses.findLeader(j);
77 // Skip if the class is the same.
78 if (jEqClass == eqClass)
79 continue;
80
81 // Get the bound of the equivalence class or itself.
82 Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
83
84 LDBG() << "- comparing with variable: #" << jEqClass
85 << ", with map: " << nv->getMap();
86
87 // Compare the variables.
88 FailureOr<bool> cmpResult =
90
91 // The variables cannot be compared.
92 if (failed(cmpResult)) {
93 LDBG() << "-- classes: #" << i << ", #" << jEqClass
94 << " cannot be merged";
95 continue;
96 }
97
98 // Join the equivalent classes and update the bound if necessary.
99 LDBG() << "-- merging classes: #" << i << ", #" << jEqClass
100 << ", is cmp(lhs, rhs): " << *cmpResult << "`";
101 if (*cmpResult) {
102 boundedClasses.join(eqClass, jEqClass);
103 } else {
104 // In this case we have lhs > rhs if isMin == true, or lhs < rhs if
105 // isMin == false.
106 bound = nv;
107 boundedClasses.join(eqClass, jEqClass);
108 }
109 }
110 bounds[boundedClasses.findLeader(i)] = bound;
111 }
112
113 // Return if there's no simplification.
114 if (bounds.size() >= affineMap.getNumResults()) {
115 LDBG() << "- the affine operation couldn't get simplified";
116 return false;
117 }
118
119 // Construct the new affine affineMap.
121 results.reserve(bounds.size());
122 for (auto [k, bound] : bounds)
123 results.push_back(bound->getMap().getResult(0));
124
125 LDBG() << "- starting from map: " << affineMap;
126 LDBG() << "- creating new map with:";
127 LDBG() << "--- dims: " << affineMap.getNumDims();
128 LDBG() << "--- syms: " << affineMap.getNumSymbols();
129 LDBG() << "--- res: " << llvm::interleaved_array(results);
130
131 affineMap =
132 AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
133 results, rewriter.getContext());
134
135 // Update the affine op.
136 rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
137 LDBG() << "- simplified affine op: `" << affineOp << "`";
138 return true;
139}
140
141bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
142 return simplifyAffineMinMaxOp(rewriter, op);
143}
144
145bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
146 return simplifyAffineMinMaxOp(rewriter, op);
147}
148
151 bool *modified) {
152 bool changed = false;
153 for (Operation *op : ops) {
154 if (auto minOp = dyn_cast<AffineMinOp>(op)) {
155 changed = simplifyAffineMinOp(rewriter, minOp) || changed;
156 continue;
157 }
158 auto maxOp = cast<AffineMaxOp>(op);
159 changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
160 }
162 AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
163 AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
164 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
165 if (modified)
166 *modified = changed;
167 // Canonicalize to a fixpoint.
168 if (failed(applyOpPatternsGreedily(
169 ops, frozenPatterns,
171 .setListener(
172 static_cast<RewriterBase::Listener *>(rewriter.getListener()))
174 &changed))) {
175 return failure();
176 }
177 if (modified)
178 *modified = changed;
179 return success();
180}
181
182namespace {
183
184struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
185 using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
186
187 LogicalResult matchAndRewrite(AffineMaxOp affineOp,
188 PatternRewriter &rewriter) const override {
189 return success(simplifyAffineMaxOp(rewriter, affineOp));
190 }
191};
192
193struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
194 using OpRewritePattern<AffineMinOp>::OpRewritePattern;
195
196 LogicalResult matchAndRewrite(AffineMinOp affineOp,
197 PatternRewriter &rewriter) const override {
198 return success(simplifyAffineMinOp(rewriter, affineOp));
199 }
200};
201
202struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
203 using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
204
205 LogicalResult matchAndRewrite(AffineApplyOp affineOp,
206 PatternRewriter &rewriter) const override {
207 AffineMap map = affineOp.getAffineMap();
208 SmallVector<Value> operands{affineOp->getOperands().begin(),
209 affineOp->getOperands().end()};
210 fullyComposeAffineMapAndOperands(&map, &operands,
211 /*composeAffineMin=*/true);
212
213 // No change => failure to apply.
214 if (map == affineOp.getAffineMap())
215 return failure();
216
217 rewriter.modifyOpInPlace(affineOp, [&]() {
218 affineOp.setMap(map);
219 affineOp->setOperands(operands);
220 });
221 return success();
222 }
223};
224
225} // namespace
226
227namespace mlir {
228namespace affine {
229#define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAXPASS
230#include "mlir/Dialect/Affine/Passes.h.inc"
231} // namespace affine
232} // namespace mlir
233
234/// Creates a simplification pass for affine min/max/apply.
237 SimplifyAffineMinMaxPass> {
238 void runOnOperation() override;
239};
240
242 FunctionOpInterface func = getOperation();
243 RewritePatternSet patterns(func.getContext());
244 AffineMaxOp::getCanonicalizationPatterns(patterns, func.getContext());
245 AffineMinOp::getCanonicalizationPatterns(patterns, func.getContext());
246 patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
247 func.getContext());
248 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
249 if (failed(applyPatternsGreedily(func, frozenPatterns)))
250 return signalPassFailure();
251}
return success()
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp)
Simplifies an affine min/max operation by proving there's a lower or upper bound.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
MLIRContext * getContext() const
Definition Builders.h:56
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
FunctionOpInterface getOperation()
Definition Pass.h:444
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:218
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
A variable that can be added to the constraint set as a "column".
static llvm::FailureOr< bool > strongCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
This function is similar to ValueBoundsConstraintSet::compare, except that it returns false if !...
ComparisonOperator
Comparison operator for ValueBoundsConstraintSet::compare.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op)
This transform tries to simplify the affine max operation op, by finding a common upper bound for a s...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op)
This transform tries to simplify the affine min operation op, by finding a common lower bound for a s...
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter, ArrayRef< Operation * > ops, bool *modified=nullptr)
This transform applies simplifyAffineMinOp and simplifyAffineMaxOp to all the affine....
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
Creates a simplification pass for affine min/max/apply.
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.