MLIR 22.0.0git
BufferDeallocationSimplification.cpp
Go to the documentation of this file.
1//===- BufferDeallocationSimplification.cpp -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic for optimizing `bufferization.dealloc` operations
10// that requires more analysis than what can be supported by regular
11// canonicalization patterns.
12//
13//===----------------------------------------------------------------------===//
14
20#include "mlir/IR/Matchers.h"
22
23namespace mlir {
24namespace bufferization {
25#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATIONPASS
26#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27} // namespace bufferization
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::bufferization;
32
33//===----------------------------------------------------------------------===//
34// Helpers
35//===----------------------------------------------------------------------===//
37/// Given a memref value, return the "base" value by skipping over all
38/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39static Value getViewBase(Value value) {
40 while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
41 if (value != viewLikeOp.getViewDest()) {
42 break;
43 }
44 value = viewLikeOp.getViewSource();
45 }
46 return value;
47}
49static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
50 ValueRange memrefs,
51 ValueRange conditions,
52 PatternRewriter &rewriter) {
53 if (deallocOp.getMemrefs() == memrefs &&
54 deallocOp.getConditions() == conditions)
55 return failure();
57 rewriter.modifyOpInPlace(deallocOp, [&]() {
58 deallocOp.getMemrefsMutable().assign(memrefs);
59 deallocOp.getConditionsMutable().assign(conditions);
60 });
61 return success();
63
64/// Return "true" if the given values are guaranteed to be different (and
65/// non-aliasing) allocations based on the fact that one value is the result
66/// of an allocation and the other value is a block argument of a parent block.
67/// Note: This is a best-effort analysis that will eventually be replaced by a
68/// proper "is same allocation" analysis. This function may return "false" even
69/// though the two values are distinct allocations.
71 Value v1Base = getViewBase(v1);
72 Value v2Base = getViewBase(v2);
73 auto areDistinct = [](Value v1, Value v2) {
74 if (Operation *op = v1.getDefiningOp())
76 if (auto bbArg = dyn_cast<BlockArgument>(v2))
77 if (bbArg.getOwner()->findAncestorOpInBlock(*op))
78 return true;
79 return false;
80 };
81 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
82}
83
84/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
85/// often a requirement of optimization patterns that there cannot be any
86/// aliasing memref in order to perform the desired simplification.
88 ValueRange otherList, Value memref) {
89 for (auto other : otherList) {
91 continue;
92 std::optional<bool> analysisResult =
93 analysis.isSameAllocation(other, memref);
94 if (!analysisResult.has_value() || analysisResult == true)
95 return true;
96 }
97 return false;
98}
99
100//===----------------------------------------------------------------------===//
101// Patterns
102//===----------------------------------------------------------------------===//
103
104namespace {
105
106/// Remove values from the `memref` operand list that are also present in the
107/// `retained` list (or a guaranteed alias of it) because they will never
108/// actually be deallocated. However, we also need to be certain about which
109/// other memrefs in the `retained` list can alias, i.e., there must not by any
110/// may-aliasing memref. This is necessary because the `dealloc` operation is
111/// defined to return one `i1` value per memref in the `retained` list which
112/// represents the disjunction of the condition values corresponding to all
113/// aliasing values in the `memref` list. In particular, this means that if
114/// there is some value R in the `retained` list which aliases with a value M in
115/// the `memref` list (but can only be staticaly determined to may-alias) and M
116/// is also present in the `retained` list, then it would be illegal to remove M
117/// because the result corresponding to R would be computed incorrectly
118/// afterwards. Because we require an alias analysis, this pattern cannot be
119/// applied as a regular canonicalization pattern.
120///
121/// Example:
122/// ```mlir
123/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
124/// retain (%m0, %r0, %r1 : ...)
125/// ```
126/// is canonicalized to
127/// ```mlir
128/// // bufferization.dealloc without memrefs and conditions returns %false for
129/// // every retained value
130/// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
131/// %1 = arith.ori %0#0, %cond0 : i1
132/// // replace %0#0 with %1
133/// ```
134/// given that `%r0` and `%r1` may not alias with `%m0`.
135struct RemoveDeallocMemrefsContainedInRetained
136 : public OpRewritePattern<DeallocOp> {
137 RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
138 BufferOriginAnalysis &analysis)
139 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
140
141 /// The passed 'memref' must not have a may-alias relation to any retained
142 /// memref, and at least one must-alias relation. If there is no must-aliasing
143 /// memref in the retain list, we cannot simply remove the memref as there
144 /// could be situations in which it actually has to be deallocated. If it's
145 /// no-alias, then just proceed, if it's must-alias we need to update the
146 /// updated condition returned by the dealloc operation for that alias.
147 LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
148 PatternRewriter &rewriter) const {
149 rewriter.setInsertionPointAfter(deallocOp);
150
151 // Check that there is no may-aliasing memref and that at least one memref
152 // in the retain list aliases (because otherwise it might have to be
153 // deallocated in some situations and can thus not be dropped).
154 bool atLeastOneMustAlias = false;
155 for (Value retained : deallocOp.getRetained()) {
156 std::optional<bool> analysisResult =
157 analysis.isSameAllocation(retained, memref);
158 if (!analysisResult.has_value())
159 return failure();
160 if (analysisResult == true)
161 atLeastOneMustAlias = true;
162 }
163 if (!atLeastOneMustAlias)
164 return failure();
165
166 // Insert arith.ori operations to update the corresponding dealloc result
167 // values to incorporate the condition of the must-aliasing memref such that
168 // we can remove that operand later on.
169 for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
170 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
171 std::optional<bool> analysisResult =
172 analysis.isSameAllocation(retained, memref);
173 if (analysisResult == true) {
174 auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
175 updatedCondition, cond);
176 rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
177 disjunction);
178 }
179 }
180
181 return success();
182 }
183
184 LogicalResult matchAndRewrite(DeallocOp deallocOp,
185 PatternRewriter &rewriter) const override {
186 // There must not be any duplicates in the retain list anymore because we
187 // would miss updating one of the result values otherwise.
188 DenseSet<Value> retained(deallocOp.getRetained().begin(),
189 deallocOp.getRetained().end());
190 if (retained.size() != deallocOp.getRetained().size())
191 return failure();
192
193 SmallVector<Value> newMemrefs, newConditions;
194 for (auto [memref, cond] :
195 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
196
197 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
198 continue;
199
200 if (auto extractOp =
201 memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
202 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
203 rewriter)))
204 continue;
205
206 newMemrefs.push_back(memref);
207 newConditions.push_back(cond);
208 }
209
210 // Return failure if we don't change anything such that we don't run into an
211 // infinite loop of pattern applications.
212 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
213 rewriter);
214 }
215
216private:
217 BufferOriginAnalysis &analysis;
218};
219
220/// Remove memrefs from the `retained` list which are guaranteed to not alias
221/// any memref in the `memrefs` list. The corresponding result value can be
222/// replaced with `false` in that case according to the operation description.
223///
224/// Example:
225/// ```mlir
226/// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
227/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
228/// return %0#0, %0#1
229/// ```
230/// can be canonicalized to the following given that `%r0` and `%r1` do not
231/// alias `%m`:
232/// ```mlir
233/// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
234/// return %false, %false
235/// ```
236struct RemoveRetainedMemrefsGuaranteedToNotAlias
237 : public OpRewritePattern<DeallocOp> {
238 RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
239 BufferOriginAnalysis &analysis)
240 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
241
242 LogicalResult matchAndRewrite(DeallocOp deallocOp,
243 PatternRewriter &rewriter) const override {
244 SmallVector<Value> newRetainedMemrefs, replacements;
245
246 for (auto retainedMemref : deallocOp.getRetained()) {
247 if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
248 retainedMemref)) {
249 newRetainedMemrefs.push_back(retainedMemref);
250 replacements.push_back({});
251 continue;
252 }
253
254 replacements.push_back(arith::ConstantOp::create(
255 rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false)));
256 }
257
258 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
259 return failure();
260
261 auto newDeallocOp =
262 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
263 deallocOp.getConditions(), newRetainedMemrefs);
264 int i = 0;
265 for (auto &repl : replacements) {
266 if (!repl)
267 repl = newDeallocOp.getUpdatedConditions()[i++];
268 }
269
270 rewriter.replaceOp(deallocOp, replacements);
271 return success();
272 }
273
274private:
275 BufferOriginAnalysis &analysis;
276};
277
278/// Split off memrefs to separate dealloc operations to reduce the number of
279/// runtime checks required and enable further canonicalization of the new and
280/// simpler dealloc operations. A memref can be split off if it is guaranteed to
281/// not alias with any other memref in the `memref` operand list. The results
282/// of the old and the new dealloc operation have to be combined by computing
283/// the element-wise disjunction of them.
284///
285/// Example:
286/// ```mlir
287/// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
288/// if (%cond0, %cond1)
289/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
290/// return %0#0, %0#1
291/// ```
292/// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
293/// canonicalized to the following, thus reducing the number of runtime alias
294/// checks by 1 and potentially enabling further canonicalization of the new
295/// split-up dealloc operations.
296/// ```mlir
297/// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
298/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
299/// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
300/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
301/// %2 = arith.ori %0#0, %1#0
302/// %3 = arith.ori %0#1, %1#1
303/// return %2, %3
304/// ```
305struct SplitDeallocWhenNotAliasingAnyOther
306 : public OpRewritePattern<DeallocOp> {
307 SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
308 BufferOriginAnalysis &analysis)
309 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
310
311 LogicalResult matchAndRewrite(DeallocOp deallocOp,
312 PatternRewriter &rewriter) const override {
313 Location loc = deallocOp.getLoc();
314 if (deallocOp.getMemrefs().size() <= 1)
315 return failure();
316
317 SmallVector<Value> remainingMemrefs, remainingConditions;
318 SmallVector<SmallVector<Value>> updatedConditions;
319 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
320 Value memref = deallocOp.getMemrefs()[i];
321 Value cond = deallocOp.getConditions()[i];
322 SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
323 otherMemrefs.erase(otherMemrefs.begin() + i);
324 // Check if `memref` can split off into a separate bufferization.dealloc.
325 if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
326 // `memref` alias with other memrefs, do not split off.
327 remainingMemrefs.push_back(memref);
328 remainingConditions.push_back(cond);
329 continue;
330 }
331
332 // Create new bufferization.dealloc op for `memref`.
333 auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,
334 deallocOp.getRetained());
335 updatedConditions.push_back(
336 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())));
337 }
338
339 // Fail if no memref was split off.
340 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
341 return failure();
342
343 // Create bufferization.dealloc op for all remaining memrefs.
344 auto newDeallocOp =
345 DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,
346 deallocOp.getRetained());
347
348 // Bit-or all conditions.
349 SmallVector<Value> replacements =
350 llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()));
351 for (auto additionalConditions : updatedConditions) {
352 assert(replacements.size() == additionalConditions.size() &&
353 "expected same number of updated conditions");
354 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
355 replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],
356 additionalConditions[i]);
357 }
358 }
359 rewriter.replaceOp(deallocOp, replacements);
360 return success();
361 }
362
363private:
364 BufferOriginAnalysis &analysis;
365};
366
367/// Check for every retained memref if a must-aliasing memref exists in the
368/// 'memref' operand list with constant 'true' condition. If so, we can replace
369/// the operation result corresponding to that retained memref with 'true'. If
370/// this condition holds for all retained memrefs we can also remove the
371/// aliasing memrefs and their conditions since they will never be deallocated
372/// due to the must-alias and we don't need them to compute the result value
373/// anymore since it got replaced with 'true'.
374///
375/// Example:
376/// ```mlir
377/// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...)
378/// if (%true, %true, %true)
379/// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
380/// ```
381/// becomes
382/// ```mlir
383/// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true)
384/// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
385/// // replace %0#0 with %true
386/// // replace %0#1 with %true
387/// ```
388/// Note that the dealloc operation will still have the result values, but they
389/// don't have uses anymore.
390struct RetainedMemrefAliasingAlwaysDeallocatedMemref
391 : public OpRewritePattern<DeallocOp> {
392 RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
393 BufferOriginAnalysis &analysis)
394 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
395
396 LogicalResult matchAndRewrite(DeallocOp deallocOp,
397 PatternRewriter &rewriter) const override {
398 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
399 SmallVector<Value> newMemrefs, newConditions;
400 for (auto [memref, cond] :
401 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
402 bool canDropMemref = false;
403 for (auto [i, retained, res] : llvm::enumerate(
404 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
405 if (!matchPattern(cond, m_One()))
406 continue;
407
408 std::optional<bool> analysisResult =
409 analysis.isSameAllocation(retained, memref);
410 if (analysisResult == true) {
411 rewriter.replaceAllUsesWith(res, cond);
412 aliasesWithConstTrueMemref[i] = true;
413 canDropMemref = true;
414 continue;
415 }
416
417 // TODO: once our alias analysis is powerful enough we can remove the
418 // rest of this loop body
419 auto extractOp =
420 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
421 if (!extractOp)
422 continue;
423
424 std::optional<bool> extractAnalysisResult =
425 analysis.isSameAllocation(retained, extractOp.getOperand());
426 if (extractAnalysisResult == true) {
427 rewriter.replaceAllUsesWith(res, cond);
428 aliasesWithConstTrueMemref[i] = true;
429 canDropMemref = true;
430 }
431 }
432
433 if (!canDropMemref) {
434 newMemrefs.push_back(memref);
435 newConditions.push_back(cond);
436 }
437 }
438 if (!aliasesWithConstTrueMemref.all())
439 return failure();
440
441 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
442 rewriter);
443 }
444
445private:
446 BufferOriginAnalysis &analysis;
447};
448
449} // namespace
450
451//===----------------------------------------------------------------------===//
452// BufferDeallocationSimplificationPass
453//===----------------------------------------------------------------------===//
454
455namespace {
456
457/// The actual buffer deallocation pass that inserts and moves dealloc nodes
458/// into the right positions. Furthermore, it inserts additional clones if
459/// necessary. It uses the algorithm described at the top of the file.
460struct BufferDeallocationSimplificationPass
462 BufferDeallocationSimplificationPass> {
463 void runOnOperation() override {
464 BufferOriginAnalysis analysis(getOperation());
465 RewritePatternSet patterns(&getContext());
466 patterns.add<RemoveDeallocMemrefsContainedInRetained,
467 RemoveRetainedMemrefsGuaranteedToNotAlias,
468 SplitDeallocWhenNotAliasingAnyOther,
469 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
470 analysis);
471
473 // We don't want that the block structure changes invalidating the
474 // `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
475 // region simplification
477 getOperation(), std::move(patterns),
478 GreedyRewriteConfig().setRegionSimplificationLevel(
479 GreedySimplifyRegionLevel::Normal))))
480 signalPassFailure();
481 }
482};
483
484} // namespace
return success()
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)
Checks if memref may potentially alias a MemRef in otherList.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
b getContext())
An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
detail::InFlightRemark analysis(Location loc, RemarkOpts opts)
Report an optimization analysis remark.
Definition Remarks.h:567
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...