MLIR  19.0.0git
BufferDeallocationSimplification.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationSimplification.cpp -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for optimizing `bufferization.dealloc` operations
10 // that requires more analysis than what can be supported by regular
11 // canonicalization patterns.
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/Matchers.h"
22 
23 namespace mlir {
24 namespace bufferization {
25 #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27 } // namespace bufferization
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::bufferization;
32 
33 //===----------------------------------------------------------------------===//
34 // Helpers
35 //===----------------------------------------------------------------------===//
36 
37 /// Given a memref value, return the "base" value by skipping over all
38 /// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39 static Value getViewBase(Value value) {
40  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
41  value = viewLikeOp.getViewSource();
42  return value;
43 }
44 
45 static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
46  ValueRange memrefs,
47  ValueRange conditions,
48  PatternRewriter &rewriter) {
49  if (deallocOp.getMemrefs() == memrefs &&
50  deallocOp.getConditions() == conditions)
51  return failure();
52 
53  rewriter.modifyOpInPlace(deallocOp, [&]() {
54  deallocOp.getMemrefsMutable().assign(memrefs);
55  deallocOp.getConditionsMutable().assign(conditions);
56  });
57  return success();
58 }
59 
60 /// Return "true" if the given values are guaranteed to be different (and
61 /// non-aliasing) allocations based on the fact that one value is the result
62 /// of an allocation and the other value is a block argument of a parent block.
63 /// Note: This is a best-effort analysis that will eventually be replaced by a
64 /// proper "is same allocation" analysis. This function may return "false" even
65 /// though the two values are distinct allocations.
67  Value v1Base = getViewBase(v1);
68  Value v2Base = getViewBase(v2);
69  auto areDistinct = [](Value v1, Value v2) {
70  if (Operation *op = v1.getDefiningOp())
71  if (hasEffect<MemoryEffects::Allocate>(op, v1))
72  if (auto bbArg = dyn_cast<BlockArgument>(v2))
73  if (bbArg.getOwner()->findAncestorOpInBlock(*op))
74  return true;
75  return false;
76  };
77  return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
78 }
79 
80 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
81 /// often a requirement of optimization patterns that there cannot be any
82 /// aliasing memref in order to perform the desired simplification.
84  ValueRange otherList, Value memref) {
85  for (auto other : otherList) {
86  if (distinctAllocAndBlockArgument(other, memref))
87  continue;
88  std::optional<bool> analysisResult =
89  analysis.isSameAllocation(other, memref);
90  if (!analysisResult.has_value() || analysisResult == true)
91  return true;
92  }
93  return false;
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // Patterns
98 //===----------------------------------------------------------------------===//
99 
100 namespace {
101 
102 /// Remove values from the `memref` operand list that are also present in the
103 /// `retained` list (or a guaranteed alias of it) because they will never
104 /// actually be deallocated. However, we also need to be certain about which
105 /// other memrefs in the `retained` list can alias, i.e., there must not by any
106 /// may-aliasing memref. This is necessary because the `dealloc` operation is
107 /// defined to return one `i1` value per memref in the `retained` list which
108 /// represents the disjunction of the condition values corresponding to all
109 /// aliasing values in the `memref` list. In particular, this means that if
110 /// there is some value R in the `retained` list which aliases with a value M in
111 /// the `memref` list (but can only be staticaly determined to may-alias) and M
112 /// is also present in the `retained` list, then it would be illegal to remove M
113 /// because the result corresponding to R would be computed incorrectly
114 /// afterwards. Because we require an alias analysis, this pattern cannot be
115 /// applied as a regular canonicalization pattern.
116 ///
117 /// Example:
118 /// ```mlir
119 /// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
120 /// retain (%m0, %r0, %r1 : ...)
121 /// ```
122 /// is canonicalized to
123 /// ```mlir
124 /// // bufferization.dealloc without memrefs and conditions returns %false for
125 /// // every retained value
126 /// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
127 /// %1 = arith.ori %0#0, %cond0 : i1
128 /// // replace %0#0 with %1
129 /// ```
130 /// given that `%r0` and `%r1` may not alias with `%m0`.
131 struct RemoveDeallocMemrefsContainedInRetained
132  : public OpRewritePattern<DeallocOp> {
133  RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
134  BufferOriginAnalysis &analysis)
135  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
136 
137  /// The passed 'memref' must not have a may-alias relation to any retained
138  /// memref, and at least one must-alias relation. If there is no must-aliasing
139  /// memref in the retain list, we cannot simply remove the memref as there
140  /// could be situations in which it actually has to be deallocated. If it's
141  /// no-alias, then just proceed, if it's must-alias we need to update the
142  /// updated condition returned by the dealloc operation for that alias.
143  LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
144  PatternRewriter &rewriter) const {
145  rewriter.setInsertionPointAfter(deallocOp);
146 
147  // Check that there is no may-aliasing memref and that at least one memref
148  // in the retain list aliases (because otherwise it might have to be
149  // deallocated in some situations and can thus not be dropped).
150  bool atLeastOneMustAlias = false;
151  for (Value retained : deallocOp.getRetained()) {
152  std::optional<bool> analysisResult =
153  analysis.isSameAllocation(retained, memref);
154  if (!analysisResult.has_value())
155  return failure();
156  if (analysisResult == true)
157  atLeastOneMustAlias = true;
158  }
159  if (!atLeastOneMustAlias)
160  return failure();
161 
162  // Insert arith.ori operations to update the corresponding dealloc result
163  // values to incorporate the condition of the must-aliasing memref such that
164  // we can remove that operand later on.
165  for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
166  Value updatedCondition = deallocOp.getUpdatedConditions()[i];
167  std::optional<bool> analysisResult =
168  analysis.isSameAllocation(retained, memref);
169  if (analysisResult == true) {
170  auto disjunction = rewriter.create<arith::OrIOp>(
171  deallocOp.getLoc(), updatedCondition, cond);
172  rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
173  disjunction);
174  }
175  }
176 
177  return success();
178  }
179 
180  LogicalResult matchAndRewrite(DeallocOp deallocOp,
181  PatternRewriter &rewriter) const override {
182  // There must not be any duplicates in the retain list anymore because we
183  // would miss updating one of the result values otherwise.
184  DenseSet<Value> retained(deallocOp.getRetained().begin(),
185  deallocOp.getRetained().end());
186  if (retained.size() != deallocOp.getRetained().size())
187  return failure();
188 
189  SmallVector<Value> newMemrefs, newConditions;
190  for (auto [memref, cond] :
191  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
192 
193  if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
194  continue;
195 
196  if (auto extractOp =
197  memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
198  if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
199  rewriter)))
200  continue;
201 
202  newMemrefs.push_back(memref);
203  newConditions.push_back(cond);
204  }
205 
206  // Return failure if we don't change anything such that we don't run into an
207  // infinite loop of pattern applications.
208  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
209  rewriter);
210  }
211 
212 private:
213  BufferOriginAnalysis &analysis;
214 };
215 
216 /// Remove memrefs from the `retained` list which are guaranteed to not alias
217 /// any memref in the `memrefs` list. The corresponding result value can be
218 /// replaced with `false` in that case according to the operation description.
219 ///
220 /// Example:
221 /// ```mlir
222 /// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
223 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
224 /// return %0#0, %0#1
225 /// ```
226 /// can be canonicalized to the following given that `%r0` and `%r1` do not
227 /// alias `%m`:
228 /// ```mlir
229 /// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
230 /// return %false, %false
231 /// ```
232 struct RemoveRetainedMemrefsGuaranteedToNotAlias
233  : public OpRewritePattern<DeallocOp> {
234  RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
235  BufferOriginAnalysis &analysis)
236  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
237 
238  LogicalResult matchAndRewrite(DeallocOp deallocOp,
239  PatternRewriter &rewriter) const override {
240  SmallVector<Value> newRetainedMemrefs, replacements;
241 
242  for (auto retainedMemref : deallocOp.getRetained()) {
243  if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
244  retainedMemref)) {
245  newRetainedMemrefs.push_back(retainedMemref);
246  replacements.push_back({});
247  continue;
248  }
249 
250  replacements.push_back(rewriter.create<arith::ConstantOp>(
251  deallocOp.getLoc(), rewriter.getBoolAttr(false)));
252  }
253 
254  if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
255  return failure();
256 
257  auto newDeallocOp = rewriter.create<DeallocOp>(
258  deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
259  newRetainedMemrefs);
260  int i = 0;
261  for (auto &repl : replacements) {
262  if (!repl)
263  repl = newDeallocOp.getUpdatedConditions()[i++];
264  }
265 
266  rewriter.replaceOp(deallocOp, replacements);
267  return success();
268  }
269 
270 private:
271  BufferOriginAnalysis &analysis;
272 };
273 
274 /// Split off memrefs to separate dealloc operations to reduce the number of
275 /// runtime checks required and enable further canonicalization of the new and
276 /// simpler dealloc operations. A memref can be split off if it is guaranteed to
277 /// not alias with any other memref in the `memref` operand list. The results
278 /// of the old and the new dealloc operation have to be combined by computing
279 /// the element-wise disjunction of them.
280 ///
281 /// Example:
282 /// ```mlir
283 /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
284 /// if (%cond0, %cond1)
285 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
286 /// return %0#0, %0#1
287 /// ```
288 /// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
289 /// canonicalized to the following, thus reducing the number of runtime alias
290 /// checks by 1 and potentially enabling further canonicalization of the new
291 /// split-up dealloc operations.
292 /// ```mlir
293 /// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
294 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
295 /// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
296 /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
297 /// %2 = arith.ori %0#0, %1#0
298 /// %3 = arith.ori %0#1, %1#1
299 /// return %2, %3
300 /// ```
301 struct SplitDeallocWhenNotAliasingAnyOther
302  : public OpRewritePattern<DeallocOp> {
303  SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
304  BufferOriginAnalysis &analysis)
305  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
306 
307  LogicalResult matchAndRewrite(DeallocOp deallocOp,
308  PatternRewriter &rewriter) const override {
309  Location loc = deallocOp.getLoc();
310  if (deallocOp.getMemrefs().size() <= 1)
311  return failure();
312 
313  SmallVector<Value> remainingMemrefs, remainingConditions;
314  SmallVector<SmallVector<Value>> updatedConditions;
315  for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
316  Value memref = deallocOp.getMemrefs()[i];
317  Value cond = deallocOp.getConditions()[i];
318  SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
319  otherMemrefs.erase(otherMemrefs.begin() + i);
320  // Check if `memref` can split off into a separate bufferization.dealloc.
321  if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
322  // `memref` alias with other memrefs, do not split off.
323  remainingMemrefs.push_back(memref);
324  remainingConditions.push_back(cond);
325  continue;
326  }
327 
328  // Create new bufferization.dealloc op for `memref`.
329  auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
330  deallocOp.getRetained());
331  updatedConditions.push_back(
332  llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
333  }
334 
335  // Fail if no memref was split off.
336  if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
337  return failure();
338 
339  // Create bufferization.dealloc op for all remaining memrefs.
340  auto newDeallocOp = rewriter.create<DeallocOp>(
341  loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
342 
343  // Bit-or all conditions.
344  SmallVector<Value> replacements =
345  llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
346  for (auto additionalConditions : updatedConditions) {
347  assert(replacements.size() == additionalConditions.size() &&
348  "expected same number of updated conditions");
349  for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
350  replacements[i] = rewriter.create<arith::OrIOp>(
351  loc, replacements[i], additionalConditions[i]);
352  }
353  }
354  rewriter.replaceOp(deallocOp, replacements);
355  return success();
356  }
357 
358 private:
359  BufferOriginAnalysis &analysis;
360 };
361 
362 /// Check for every retained memref if a must-aliasing memref exists in the
363 /// 'memref' operand list with constant 'true' condition. If so, we can replace
364 /// the operation result corresponding to that retained memref with 'true'. If
365 /// this condition holds for all retained memrefs we can also remove the
366 /// aliasing memrefs and their conditions since they will never be deallocated
367 /// due to the must-alias and we don't need them to compute the result value
368 /// anymore since it got replaced with 'true'.
369 ///
370 /// Example:
371 /// ```mlir
372 /// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...)
373 /// if (%true, %true, %true)
374 /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
375 /// ```
376 /// becomes
377 /// ```mlir
378 /// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true)
379 /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
380 /// // replace %0#0 with %true
381 /// // replace %0#1 with %true
382 /// ```
383 /// Note that the dealloc operation will still have the result values, but they
384 /// don't have uses anymore.
385 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
386  : public OpRewritePattern<DeallocOp> {
387  RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
388  BufferOriginAnalysis &analysis)
389  : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
390 
391  LogicalResult matchAndRewrite(DeallocOp deallocOp,
392  PatternRewriter &rewriter) const override {
393  BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
394  SmallVector<Value> newMemrefs, newConditions;
395  for (auto [memref, cond] :
396  llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
397  bool canDropMemref = false;
398  for (auto [i, retained, res] : llvm::enumerate(
399  deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
400  if (!matchPattern(cond, m_One()))
401  continue;
402 
403  std::optional<bool> analysisResult =
404  analysis.isSameAllocation(retained, memref);
405  if (analysisResult == true) {
406  rewriter.replaceAllUsesWith(res, cond);
407  aliasesWithConstTrueMemref[i] = true;
408  canDropMemref = true;
409  continue;
410  }
411 
412  // TODO: once our alias analysis is powerful enough we can remove the
413  // rest of this loop body
414  auto extractOp =
415  memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
416  if (!extractOp)
417  continue;
418 
419  std::optional<bool> extractAnalysisResult =
420  analysis.isSameAllocation(retained, extractOp.getOperand());
421  if (extractAnalysisResult == true) {
422  rewriter.replaceAllUsesWith(res, cond);
423  aliasesWithConstTrueMemref[i] = true;
424  canDropMemref = true;
425  }
426  }
427 
428  if (!canDropMemref) {
429  newMemrefs.push_back(memref);
430  newConditions.push_back(cond);
431  }
432  }
433  if (!aliasesWithConstTrueMemref.all())
434  return failure();
435 
436  return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
437  rewriter);
438  }
439 
440 private:
441  BufferOriginAnalysis &analysis;
442 };
443 
444 } // namespace
445 
446 //===----------------------------------------------------------------------===//
447 // BufferDeallocationSimplificationPass
448 //===----------------------------------------------------------------------===//
449 
450 namespace {
451 
452 /// The actual buffer deallocation pass that inserts and moves dealloc nodes
453 /// into the right positions. Furthermore, it inserts additional clones if
454 /// necessary. It uses the algorithm described at the top of the file.
455 struct BufferDeallocationSimplificationPass
456  : public bufferization::impl::BufferDeallocationSimplificationBase<
457  BufferDeallocationSimplificationPass> {
458  void runOnOperation() override {
459  BufferOriginAnalysis analysis(getOperation());
460  RewritePatternSet patterns(&getContext());
461  patterns.add<RemoveDeallocMemrefsContainedInRetained,
462  RemoveRetainedMemrefsGuaranteedToNotAlias,
463  SplitDeallocWhenNotAliasingAnyOther,
464  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465  analysis);
467 
468  if (failed(
469  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
470  signalPassFailure();
471  }
472 };
473 
474 } // namespace
475 
476 std::unique_ptr<Pass>
478  return std::make_unique<BufferDeallocationSimplificationPass>();
479 }
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)
Checks if memref may potentially alias a MemRef in otherList.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static MLIRContext * getContext(OpFoldResult val)
An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.
std::optional< bool > isSameAllocation(Value v1, Value v2)
Return "true" if v1 and v2 originate from the same buffer allocation.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
std::unique_ptr< Pass > createBufferDeallocationSimplificationPass()
Creates a pass that optimizes bufferization.dealloc operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358