MLIR  22.0.0git
BufferDeallocationSimplification.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationSimplification.cpp -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for optimizing `bufferization.dealloc` operations
10 // that requires more analysis than what can be supported by regular
11 // canonicalization patterns.
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/Matchers.h"
22 
23 namespace mlir {
24 namespace bufferization {
25 #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATIONPASS
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27 } // namespace bufferization
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::bufferization;
32 
33 //===----------------------------------------------------------------------===//
34 // Helpers
35 //===----------------------------------------------------------------------===//
36 
37 /// Given a memref value, return the "base" value by skipping over all
38 /// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39 static Value getViewBase(Value value) {
40  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
41  if (value != viewLikeOp.getViewDest()) {
42  break;
43  }
44  value = viewLikeOp.getViewSource();
45  }
46  return value;
47 }
48 
49 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
50  ValueRange memrefs,
51  ValueRange conditions,
52  PatternRewriter &rewriter) {
53  if (deallocOp.getMemrefs() == memrefs &&
54  deallocOp.getConditions() == conditions)
55  return failure();
56 
57  rewriter.modifyOpInPlace(deallocOp, [&]() {
58  deallocOp.getMemrefsMutable().assign(memrefs);
59  deallocOp.getConditionsMutable().assign(conditions);
60  });
61  return success();
62 }
63 
64 /// Return "true" if the given values are guaranteed to be different (and
65 /// non-aliasing) allocations based on the fact that one value is the result
66 /// of an allocation and the other value is a block argument of a parent block.
67 /// Note: This is a best-effort analysis that will eventually be replaced by a
68 /// proper "is same allocation" analysis. This function may return "false" even
69 /// though the two values are distinct allocations.
71  Value v1Base = getViewBase(v1);
72  Value v2Base = getViewBase(v2);
73  auto areDistinct = [](Value v1, Value v2) {
74  if (Operation *op = v1.getDefiningOp())
75  if (hasEffect<MemoryEffects::Allocate>(op, v1))
76  if (auto bbArg = dyn_cast<BlockArgument>(v2))
77  if (bbArg.getOwner()->findAncestorOpInBlock(*op))
78  return true;
79  return false;
80  };
81  return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
82 }
83 
84 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
85 /// often a requirement of optimization patterns that there cannot be any
86 /// aliasing memref in order to perform the desired simplification.
88  ValueRange otherList, Value memref) {
89  for (auto other : otherList) {
90  if (distinctAllocAndBlockArgument(other, memref))
91  continue;
92  std::optional<bool> analysisResult =
93  analysis.isSameAllocation(other, memref);
94  if (!analysisResult.has_value() || analysisResult == true)
95  return true;
96  }
97  return false;
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // Patterns
102 //===----------------------------------------------------------------------===//
103 
104 namespace {
105 
106 /// Remove values from the `memref` operand list that are also present in the
107 /// `retained` list (or a guaranteed alias of it) because they will never
108 /// actually be deallocated. However, we also need to be certain about which
109 /// other memrefs in the `retained` list can alias, i.e., there must not by any
110 /// may-aliasing memref. This is necessary because the `dealloc` operation is
111 /// defined to return one `i1` value per memref in the `retained` list which
112 /// represents the disjunction of the condition values corresponding to all
113 /// aliasing values in the `memref` list. In particular, this means that if
114 /// there is some value R in the `retained` list which aliases with a value M in
115 /// the `memref` list (but can only be staticaly determined to may-alias) and M
116 /// is also present in the `retained` list, then it would be illegal to remove M
117 /// because the result corresponding to R would be computed incorrectly
118 /// afterwards. Because we require an alias analysis, this pattern cannot be
119 /// applied as a regular canonicalization pattern.
120 ///
121 /// Example:
122 /// ```mlir
123 /// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
124 /// retain (%m0, %r0, %r1 : ...)
125 /// ```
126 /// is canonicalized to
127 /// ```mlir
128 /// // bufferization.dealloc without memrefs and conditions returns %false for
129 /// // every retained value
130 /// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
131 /// %1 = arith.ori %0#0, %cond0 : i1
132 /// // replace %0#0 with %1
133 /// ```
134 /// given that `%r0` and `%r1` may not alias with `%m0`.
135 struct RemoveDeallocMemrefsContainedInRetained
136  : public OpRewritePattern<DeallocOp> {
137  RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
139  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
140 
141  /// The passed 'memref' must not have a may-alias relation to any retained
142  /// memref, and at least one must-alias relation. If there is no must-aliasing
143  /// memref in the retain list, we cannot simply remove the memref as there
144  /// could be situations in which it actually has to be deallocated. If it's
145  /// no-alias, then just proceed, if it's must-alias we need to update the
146  /// updated condition returned by the dealloc operation for that alias.
147  LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
148  PatternRewriter &rewriter) const {
149  rewriter.setInsertionPointAfter(deallocOp);
150 
151  // Check that there is no may-aliasing memref and that at least one memref
152  // in the retain list aliases (because otherwise it might have to be
153  // deallocated in some situations and can thus not be dropped).
154  bool atLeastOneMustAlias = false;
155  for (Value retained : deallocOp.getRetained()) {
156  std::optional<bool> analysisResult =
157  analysis.isSameAllocation(retained, memref);
158  if (!analysisResult.has_value())
159  return failure();
160  if (analysisResult == true)
161  atLeastOneMustAlias = true;
162  }
163  if (!atLeastOneMustAlias)
164  return failure();
165 
166  // Insert arith.ori operations to update the corresponding dealloc result
167  // values to incorporate the condition of the must-aliasing memref such that
168  // we can remove that operand later on.
169  for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
170  Value updatedCondition = deallocOp.getUpdatedConditions()[i];
171  std::optional<bool> analysisResult =
172  analysis.isSameAllocation(retained, memref);
173  if (analysisResult == true) {
174  auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
175  updatedCondition, cond);
176  rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
177  disjunction);
178  }
179  }
180 
181  return success();
182  }
183 
184  LogicalResult matchAndRewrite(DeallocOp deallocOp,
185  PatternRewriter &rewriter) const override {
186  // There must not be any duplicates in the retain list anymore because we
187  // would miss updating one of the result values otherwise.
188  DenseSet<Value> retained(deallocOp.getRetained().begin(),
189  deallocOp.getRetained().end());
190  if (retained.size() != deallocOp.getRetained().size())
191  return failure();
192 
193  SmallVector<Value> newMemrefs, newConditions;
194  for (auto [memref, cond] :
195  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
196 
197  if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
198  continue;
199 
200  if (auto extractOp =
201  memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
202  if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
203  rewriter)))
204  continue;
205 
206  newMemrefs.push_back(memref);
207  newConditions.push_back(cond);
208  }
209 
210  // Return failure if we don't change anything such that we don't run into an
211  // infinite loop of pattern applications.
212  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
213  rewriter);
214  }
215 
216 private:
218 };
219 
220 /// Remove memrefs from the `retained` list which are guaranteed to not alias
221 /// any memref in the `memrefs` list. The corresponding result value can be
222 /// replaced with `false` in that case according to the operation description.
223 ///
224 /// Example:
225 /// ```mlir
226 /// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
227 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
228 /// return %0#0, %0#1
229 /// ```
230 /// can be canonicalized to the following given that `%r0` and `%r1` do not
231 /// alias `%m`:
232 /// ```mlir
233 /// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
234 /// return %false, %false
235 /// ```
236 struct RemoveRetainedMemrefsGuaranteedToNotAlias
237  : public OpRewritePattern<DeallocOp> {
238  RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
240  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
241 
242  LogicalResult matchAndRewrite(DeallocOp deallocOp,
243  PatternRewriter &rewriter) const override {
244  SmallVector<Value> newRetainedMemrefs, replacements;
245 
246  for (auto retainedMemref : deallocOp.getRetained()) {
247  if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
248  retainedMemref)) {
249  newRetainedMemrefs.push_back(retainedMemref);
250  replacements.push_back({});
251  continue;
252  }
253 
254  replacements.push_back(arith::ConstantOp::create(
255  rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false)));
256  }
257 
258  if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
259  return failure();
260 
261  auto newDeallocOp =
262  DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
263  deallocOp.getConditions(), newRetainedMemrefs);
264  int i = 0;
265  for (auto &repl : replacements) {
266  if (!repl)
267  repl = newDeallocOp.getUpdatedConditions()[i++];
268  }
269 
270  rewriter.replaceOp(deallocOp, replacements);
271  return success();
272  }
273 
274 private:
276 };
277 
278 /// Split off memrefs to separate dealloc operations to reduce the number of
279 /// runtime checks required and enable further canonicalization of the new and
280 /// simpler dealloc operations. A memref can be split off if it is guaranteed to
281 /// not alias with any other memref in the `memref` operand list. The results
282 /// of the old and the new dealloc operation have to be combined by computing
283 /// the element-wise disjunction of them.
284 ///
285 /// Example:
286 /// ```mlir
287 /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
288 /// if (%cond0, %cond1)
289 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
290 /// return %0#0, %0#1
291 /// ```
292 /// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
293 /// canonicalized to the following, thus reducing the number of runtime alias
294 /// checks by 1 and potentially enabling further canonicalization of the new
295 /// split-up dealloc operations.
296 /// ```mlir
297 /// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
298 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
299 /// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
300 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
301 /// %2 = arith.ori %0#0, %1#0
302 /// %3 = arith.ori %0#1, %1#1
303 /// return %2, %3
304 /// ```
305 struct SplitDeallocWhenNotAliasingAnyOther
306  : public OpRewritePattern<DeallocOp> {
307  SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
309  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
310 
311  LogicalResult matchAndRewrite(DeallocOp deallocOp,
312  PatternRewriter &rewriter) const override {
313  Location loc = deallocOp.getLoc();
314  if (deallocOp.getMemrefs().size() <= 1)
315  return failure();
316 
317  SmallVector<Value> remainingMemrefs, remainingConditions;
318  SmallVector<SmallVector<Value>> updatedConditions;
319  for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
320  Value memref = deallocOp.getMemrefs()[i];
321  Value cond = deallocOp.getConditions()[i];
322  SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
323  otherMemrefs.erase(otherMemrefs.begin() + i);
324  // Check if `memref` can split off into a separate bufferization.dealloc.
325  if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
326  // `memref` alias with other memrefs, do not split off.
327  remainingMemrefs.push_back(memref);
328  remainingConditions.push_back(cond);
329  continue;
330  }
331 
332  // Create new bufferization.dealloc op for `memref`.
333  auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,
334  deallocOp.getRetained());
335  updatedConditions.push_back(
336  llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
337  }
338 
339  // Fail if no memref was split off.
340  if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
341  return failure();
342 
343  // Create bufferization.dealloc op for all remaining memrefs.
344  auto newDeallocOp =
345  DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,
346  deallocOp.getRetained());
347 
348  // Bit-or all conditions.
349  SmallVector<Value> replacements =
350  llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
351  for (auto additionalConditions : updatedConditions) {
352  assert(replacements.size() == additionalConditions.size() &&
353  "expected same number of updated conditions");
354  for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
355  replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],
356  additionalConditions[i]);
357  }
358  }
359  rewriter.replaceOp(deallocOp, replacements);
360  return success();
361  }
362 
363 private:
365 };
366 
367 /// Check for every retained memref if a must-aliasing memref exists in the
368 /// 'memref' operand list with constant 'true' condition. If so, we can replace
369 /// the operation result corresponding to that retained memref with 'true'. If
370 /// this condition holds for all retained memrefs we can also remove the
371 /// aliasing memrefs and their conditions since they will never be deallocated
372 /// due to the must-alias and we don't need them to compute the result value
373 /// anymore since it got replaced with 'true'.
374 ///
375 /// Example:
376 /// ```mlir
377 /// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...)
378 /// if (%true, %true, %true)
379 /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
380 /// ```
381 /// becomes
382 /// ```mlir
383 /// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true)
384 /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
385 /// // replace %0#0 with %true
386 /// // replace %0#1 with %true
387 /// ```
388 /// Note that the dealloc operation will still have the result values, but they
389 /// don't have uses anymore.
390 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
391  : public OpRewritePattern<DeallocOp> {
392  RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
394  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
395 
396  LogicalResult matchAndRewrite(DeallocOp deallocOp,
397  PatternRewriter &rewriter) const override {
398  BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
399  SmallVector<Value> newMemrefs, newConditions;
400  for (auto [memref, cond] :
401  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
402  bool canDropMemref = false;
403  for (auto [i, retained, res] : llvm::enumerate(
404  deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
405  if (!matchPattern(cond, m_One()))
406  continue;
407 
408  std::optional<bool> analysisResult =
409  analysis.isSameAllocation(retained, memref);
410  if (analysisResult == true) {
411  rewriter.replaceAllUsesWith(res, cond);
412  aliasesWithConstTrueMemref[i] = true;
413  canDropMemref = true;
414  continue;
415  }
416 
417  // TODO: once our alias analysis is powerful enough we can remove the
418  // rest of this loop body
419  auto extractOp =
420  memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
421  if (!extractOp)
422  continue;
423 
424  std::optional<bool> extractAnalysisResult =
425  analysis.isSameAllocation(retained, extractOp.getOperand());
426  if (extractAnalysisResult == true) {
427  rewriter.replaceAllUsesWith(res, cond);
428  aliasesWithConstTrueMemref[i] = true;
429  canDropMemref = true;
430  }
431  }
432 
433  if (!canDropMemref) {
434  newMemrefs.push_back(memref);
435  newConditions.push_back(cond);
436  }
437  }
438  if (!aliasesWithConstTrueMemref.all())
439  return failure();
440 
441  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
442  rewriter);
443  }
444 
445 private:
447 };
448 
449 } // namespace
450 
451 //===----------------------------------------------------------------------===//
452 // BufferDeallocationSimplificationPass
453 //===----------------------------------------------------------------------===//
454 
455 namespace {
456 
457 /// The actual buffer deallocation pass that inserts and moves dealloc nodes
458 /// into the right positions. Furthermore, it inserts additional clones if
459 /// necessary. It uses the algorithm described at the top of the file.
460 struct BufferDeallocationSimplificationPass
461  : public bufferization::impl::BufferDeallocationSimplificationPassBase<
462  BufferDeallocationSimplificationPass> {
463  void runOnOperation() override {
464  BufferOriginAnalysis analysis(getOperation());
466  patterns.add<RemoveDeallocMemrefsContainedInRetained,
467  RemoveRetainedMemrefsGuaranteedToNotAlias,
468  SplitDeallocWhenNotAliasingAnyOther,
469  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
470  analysis);
471 
473  // We don't want that the block structure changes invalidating the
474  // `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
475  // region simplification
477  getOperation(), std::move(patterns),
478  GreedyRewriteConfig().setRegionSimplificationLevel(
480  signalPassFailure();
481  }
482 };
483 
484 } // namespace
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)
Checks if memref may potentially alias a MemRef in otherList.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static MLIRContext * getContext(OpFoldResult val)
An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:700
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::InFlightRemark analysis(Location loc, RemarkOpts opts)
Report an optimization analysis remark.
Definition: Remarks.h:497
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
@ Normal
Run the normal simplification (e.g. dead args elimination).
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314