MLIR  22.0.0git
SimplifyAffineMinMax.cpp
Go to the documentation of this file.
1 //===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a transform to simplify min/max affine operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
17 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/ADT/IntEqClasses.h"
22 #include "llvm/Support/DebugLog.h"
23 #include "llvm/Support/InterleavedRange.h"
24 
25 #define DEBUG_TYPE "affine-min-max"
26 
27 using namespace mlir;
28 using namespace mlir::affine;
29 
30 /// Simplifies an affine min/max operation by proving there's a lower or upper
31 /// bound.
32 template <typename AffineOp>
33 static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
34  using Variable = ValueBoundsConstraintSet::Variable;
35  using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
36 
37  AffineMap affineMap = affineOp.getMap();
38  ValueRange operands = affineOp.getOperands();
39  static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
40 
41  LDBG() << "analyzing value: `" << affineOp;
42 
43  // Create a `Variable` list with values corresponding to each of the results
44  // in the affine affineMap.
45  SmallVector<Variable> variables = llvm::map_to_vector(
46  llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
47  [&](unsigned i) {
48  return Variable(affineMap.getSliceMap(i, 1), operands);
49  });
50  LDBG() << "- constructed variables are: "
51  << llvm::interleaved_array(llvm::map_range(
52  variables, [](const Variable &v) { return v.getMap(); }));
53 
54  // Get the comparison operation.
55  ComparisonOperator cmpOp =
56  isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
57 
58  // Find disjoint sets bounded by a common value.
59  llvm::IntEqClasses boundedClasses(variables.size());
61  for (auto &&[i, v] : llvm::enumerate(variables)) {
62  unsigned eqClass = boundedClasses.findLeader(i);
63 
64  // If the class already has a bound continue.
65  if (bounds.contains(eqClass))
66  continue;
67 
68  // Initialize the bound.
69  Variable *bound = &v;
70 
71  LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
72  << "`\n";
73 
74  // Check against the other variables.
75  for (size_t j = i + 1; j < variables.size(); ++j) {
76  unsigned jEqClass = boundedClasses.findLeader(j);
77  // Skip if the class is the same.
78  if (jEqClass == eqClass)
79  continue;
80 
81  // Get the bound of the equivalence class or itself.
82  Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
83 
84  LDBG() << "- comparing with variable: #" << jEqClass
85  << ", with map: " << nv->getMap();
86 
87  // Compare the variables.
88  FailureOr<bool> cmpResult =
89  ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
90 
91  // The variables cannot be compared.
92  if (failed(cmpResult)) {
93  LDBG() << "-- classes: #" << i << ", #" << jEqClass
94  << " cannot be merged";
95  continue;
96  }
97 
98  // Join the equivalent classes and update the bound if necessary.
99  LDBG() << "-- merging classes: #" << i << ", #" << jEqClass
100  << ", is cmp(lhs, rhs): " << *cmpResult << "`";
101  if (*cmpResult) {
102  boundedClasses.join(eqClass, jEqClass);
103  } else {
104  // In this case we have lhs > rhs if isMin == true, or lhs < rhs if
105  // isMin == false.
106  bound = nv;
107  boundedClasses.join(eqClass, jEqClass);
108  }
109  }
110  bounds[boundedClasses.findLeader(i)] = bound;
111  }
112 
113  // Return if there's no simplification.
114  if (bounds.size() >= affineMap.getNumResults()) {
115  LDBG() << "- the affine operation couldn't get simplified";
116  return false;
117  }
118 
119  // Construct the new affine affineMap.
120  SmallVector<AffineExpr> results;
121  results.reserve(bounds.size());
122  for (auto [k, bound] : bounds)
123  results.push_back(bound->getMap().getResult(0));
124 
125  LDBG() << "- starting from map: " << affineMap;
126  LDBG() << "- creating new map with:";
127  LDBG() << "--- dims: " << affineMap.getNumDims();
128  LDBG() << "--- syms: " << affineMap.getNumSymbols();
129  LDBG() << "--- res: " << llvm::interleaved_array(results);
130 
131  affineMap =
132  AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
133  results, rewriter.getContext());
134 
135  // Update the affine op.
136  rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
137  LDBG() << "- simplified affine op: `" << affineOp << "`";
138  return true;
139 }
140 
141 bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
142  return simplifyAffineMinMaxOp(rewriter, op);
143 }
144 
145 bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
146  return simplifyAffineMinMaxOp(rewriter, op);
147 }
148 
151  bool *modified) {
152  bool changed = false;
153  for (Operation *op : ops) {
154  if (auto minOp = dyn_cast<AffineMinOp>(op))
155  changed = simplifyAffineMinOp(rewriter, minOp) || changed;
156  else if (auto maxOp = cast<AffineMaxOp>(op))
157  changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
158  }
160  AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
161  AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
162  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
163  if (modified)
164  *modified = changed;
165  // Canonicalize to a fixpoint.
166  if (failed(applyOpPatternsGreedily(
167  ops, frozenPatterns,
169  .setListener(
170  static_cast<RewriterBase::Listener *>(rewriter.getListener()))
172  &changed))) {
173  return failure();
174  }
175  if (modified)
176  *modified = changed;
177  return success();
178 }
179 
180 namespace {
181 
182 struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
184 
185  LogicalResult matchAndRewrite(AffineMaxOp affineOp,
186  PatternRewriter &rewriter) const override {
187  return success(simplifyAffineMaxOp(rewriter, affineOp));
188  }
189 };
190 
191 struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
193 
194  LogicalResult matchAndRewrite(AffineMinOp affineOp,
195  PatternRewriter &rewriter) const override {
196  return success(simplifyAffineMinOp(rewriter, affineOp));
197  }
198 };
199 
200 struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
202 
203  LogicalResult matchAndRewrite(AffineApplyOp affineOp,
204  PatternRewriter &rewriter) const override {
205  AffineMap map = affineOp.getAffineMap();
206  SmallVector<Value> operands{affineOp->getOperands().begin(),
207  affineOp->getOperands().end()};
208  fullyComposeAffineMapAndOperands(&map, &operands,
209  /*composeAffineMin=*/true);
210 
211  // No change => failure to apply.
212  if (map == affineOp.getAffineMap())
213  return failure();
214 
215  rewriter.modifyOpInPlace(affineOp, [&]() {
216  affineOp.setMap(map);
217  affineOp->setOperands(operands);
218  });
219  return success();
220  }
221 };
222 
223 } // namespace
224 
225 namespace mlir {
226 namespace affine {
227 #define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAXPASS
228 #include "mlir/Dialect/Affine/Passes.h.inc"
229 } // namespace affine
230 } // namespace mlir
231 
232 /// Creates a simplification pass for affine min/max/apply.
234  : public affine::impl::SimplifyAffineMinMaxPassBase<
235  SimplifyAffineMinMaxPass> {
236  void runOnOperation() override;
237 };
238 
240  FunctionOpInterface func = getOperation();
241  RewritePatternSet patterns(func.getContext());
242  AffineMaxOp::getCanonicalizationPatterns(patterns, func.getContext());
243  AffineMinOp::getCanonicalizationPatterns(patterns, func.getContext());
244  patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
245  func.getContext());
246  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
247  if (failed(applyPatternsGreedily(func, std::move(frozenPatterns))))
248  return signalPassFailure();
249 }
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp)
Simplifies an affine min/max operation by proving there's a lower or upper bound.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
MLIRContext * getContext() const
Definition: Builders.h:55
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
A variable that can be added to the constraint set as a "column".
static llvm::FailureOr< bool > strongCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
This function is similar to ValueBoundsConstraintSet::compare, except that it returns false if !...
ComparisonOperator
Comparison operator for ValueBoundsConstraintSet::compare.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op)
This transform tries to simplify the affine max operation op, by finding a common upper bound for a s...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1258
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op)
This transform tries to simplify the affine min operation op, by finding a common lower bound for a s...
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter, ArrayRef< Operation * > ops, bool *modified=nullptr)
This transform applies simplifyAffineMinOp and simplifyAffineMaxOp to all the affine....
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
Creates a simplification pass for affine min/max/apply.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.