MLIR  21.0.0git
SimplifyAffineMinMax.cpp
Go to the documentation of this file.
1 //===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a transform to simplify min/max affine operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
17 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/ADT/IntEqClasses.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/InterleavedRange.h"
24 
25 #define DEBUG_TYPE "affine-min-max"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
27 
28 using namespace mlir;
29 using namespace mlir::affine;
30 
31 /// Simplifies an affine min/max operation by proving there's a lower or upper
32 /// bound.
33 template <typename AffineOp>
34 static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
35  using Variable = ValueBoundsConstraintSet::Variable;
36  using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
37 
38  AffineMap affineMap = affineOp.getMap();
39  ValueRange operands = affineOp.getOperands();
40  static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
41 
42  LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
43 
44  // Create a `Variable` list with values corresponding to each of the results
45  // in the affine affineMap.
46  SmallVector<Variable> variables = llvm::map_to_vector(
47  llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
48  [&](unsigned i) {
49  return Variable(affineMap.getSliceMap(i, 1), operands);
50  });
51  LLVM_DEBUG({
52  DBGS() << "- constructed variables are: "
53  << llvm::interleaved_array(llvm::map_range(
54  variables, [](const Variable &v) { return v.getMap(); }))
55  << "`\n";
56  });
57 
58  // Get the comparison operation.
59  ComparisonOperator cmpOp =
60  isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
61 
62  // Find disjoint sets bounded by a common value.
63  llvm::IntEqClasses boundedClasses(variables.size());
65  for (auto &&[i, v] : llvm::enumerate(variables)) {
66  unsigned eqClass = boundedClasses.findLeader(i);
67 
68  // If the class already has a bound continue.
69  if (bounds.contains(eqClass))
70  continue;
71 
72  // Initialize the bound.
73  Variable *bound = &v;
74 
75  LLVM_DEBUG({
76  DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
77  << "`\n";
78  });
79 
80  // Check against the other variables.
81  for (size_t j = i + 1; j < variables.size(); ++j) {
82  unsigned jEqClass = boundedClasses.findLeader(j);
83  // Skip if the class is the same.
84  if (jEqClass == eqClass)
85  continue;
86 
87  // Get the bound of the equivalence class or itself.
88  Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
89 
90  LLVM_DEBUG({
91  DBGS() << "- comparing with variable: #" << jEqClass
92  << ", with map: " << nv->getMap() << "\n";
93  });
94 
95  // Compare the variables.
96  FailureOr<bool> cmpResult =
97  ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
98 
99  // The variables cannot be compared.
100  if (failed(cmpResult)) {
101  LLVM_DEBUG({
102  DBGS() << "-- classes: #" << i << ", #" << jEqClass
103  << " cannot be merged\n";
104  });
105  continue;
106  }
107 
108  // Join the equivalent classes and update the bound if necessary.
109  LLVM_DEBUG({
110  DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
111  << ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
112  });
113  if (*cmpResult) {
114  boundedClasses.join(eqClass, jEqClass);
115  } else {
116  // In this case we have lhs > rhs if isMin == true, or lhs < rhs if
117  // isMin == false.
118  bound = nv;
119  boundedClasses.join(eqClass, jEqClass);
120  }
121  }
122  bounds[boundedClasses.findLeader(i)] = bound;
123  }
124 
125  // Return if there's no simplification.
126  if (bounds.size() >= affineMap.getNumResults()) {
127  LLVM_DEBUG(
128  { DBGS() << "- the affine operation couldn't get simplified\n"; });
129  return false;
130  }
131 
132  // Construct the new affine affineMap.
133  SmallVector<AffineExpr> results;
134  results.reserve(bounds.size());
135  for (auto [k, bound] : bounds)
136  results.push_back(bound->getMap().getResult(0));
137 
138  LLVM_DEBUG({
139  DBGS() << "- starting from map: " << affineMap << "\n";
140  DBGS() << "- creating new map with: \n";
141  DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
142  DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
143  DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
144  });
145 
146  affineMap =
147  AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
148  results, rewriter.getContext());
149 
150  // Update the affine op.
151  rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
152  LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
153  return true;
154 }
155 
156 bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
157  return simplifyAffineMinMaxOp(rewriter, op);
158 }
159 
160 bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
161  return simplifyAffineMinMaxOp(rewriter, op);
162 }
163 
166  bool *modified) {
167  bool changed = false;
168  for (Operation *op : ops) {
169  if (auto minOp = dyn_cast<AffineMinOp>(op))
170  changed = simplifyAffineMinOp(rewriter, minOp) || changed;
171  else if (auto maxOp = cast<AffineMaxOp>(op))
172  changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
173  }
175  AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
176  AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
177  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
178  if (modified)
179  *modified = changed;
180  // Canonicalize to a fixpoint.
181  if (failed(applyOpPatternsGreedily(
182  ops, frozenPatterns,
184  .setListener(
185  static_cast<RewriterBase::Listener *>(rewriter.getListener()))
187  &changed))) {
188  return failure();
189  }
190  if (modified)
191  *modified = changed;
192  return success();
193 }
194 
195 namespace {
196 
197 struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
199 
200  LogicalResult matchAndRewrite(AffineMaxOp affineOp,
201  PatternRewriter &rewriter) const override {
202  return success(simplifyAffineMaxOp(rewriter, affineOp));
203  }
204 };
205 
206 struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
208 
209  LogicalResult matchAndRewrite(AffineMinOp affineOp,
210  PatternRewriter &rewriter) const override {
211  return success(simplifyAffineMinOp(rewriter, affineOp));
212  }
213 };
214 
215 struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
217 
218  LogicalResult matchAndRewrite(AffineApplyOp affineOp,
219  PatternRewriter &rewriter) const override {
220  AffineMap map = affineOp.getAffineMap();
221  SmallVector<Value> operands{affineOp->getOperands().begin(),
222  affineOp->getOperands().end()};
223  fullyComposeAffineMapAndOperands(&map, &operands,
224  /*composeAffineMin=*/true);
225 
226  // No change => failure to apply.
227  if (map == affineOp.getAffineMap())
228  return failure();
229 
230  rewriter.modifyOpInPlace(affineOp, [&]() {
231  affineOp.setMap(map);
232  affineOp->setOperands(operands);
233  });
234  return success();
235  }
236 };
237 
238 } // namespace
239 
240 namespace mlir {
241 namespace affine {
242 #define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAXPASS
243 #include "mlir/Dialect/Affine/Passes.h.inc"
244 } // namespace affine
245 } // namespace mlir
246 
247 /// Creates a simplification pass for affine min/max/apply.
249  : public affine::impl::SimplifyAffineMinMaxPassBase<
250  SimplifyAffineMinMaxPass> {
251  void runOnOperation() override;
252 };
253 
255  FunctionOpInterface func = getOperation();
256  RewritePatternSet patterns(func.getContext());
257  AffineMaxOp::getCanonicalizationPatterns(patterns, func.getContext());
258  AffineMinOp::getCanonicalizationPatterns(patterns, func.getContext());
259  patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
260  func.getContext());
261  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
262  if (failed(applyPatternsGreedily(func, std::move(frozenPatterns))))
263  return signalPassFailure();
264 }
#define DBGS()
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp)
Simplifies an affine min/max operation by proving there's a lower or upper bound.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
MLIRContext * getContext() const
Definition: Builders.h:55
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
A variable that can be added to the constraint set as a "column".
static llvm::FailureOr< bool > strongCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
This function is similar to ValueBoundsConstraintSet::compare, except that it returns false if !...
ComparisonOperator
Comparison operator for ValueBoundsConstraintSet::compare.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op)
This transform tries to simplify the affine max operation op, by finding a common upper bound for a s...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1262
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op)
This transform tries to simplify the affine min operation op, by finding a common lower bound for a s...
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter, ArrayRef< Operation * > ops, bool *modified=nullptr)
This transform applies simplifyAffineMinOp and simplifyAffineMaxOp to all the affine....
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
Creates a simplification pass for affine min/max/apply.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.