25#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATIONPASS
26#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
40 while (
auto viewLikeOp = value.
getDefiningOp<ViewLikeOpInterface>()) {
41 if (value != viewLikeOp.getViewDest()) {
44 value = viewLikeOp.getViewSource();
53 if (deallocOp.getMemrefs() == memrefs &&
54 deallocOp.getConditions() == conditions)
58 deallocOp.getMemrefsMutable().assign(memrefs);
59 deallocOp.getConditionsMutable().assign(conditions);
76 if (
auto bbArg = dyn_cast<BlockArgument>(v2))
77 if (bbArg.getOwner()->findAncestorOpInBlock(*op))
81 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
89 for (
auto other : otherList) {
92 std::optional<bool> analysisResult =
93 analysis.isSameAllocation(other,
memref);
94 if (!analysisResult.has_value() || analysisResult ==
true)
135struct RemoveDeallocMemrefsContainedInRetained
137 RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
138 BufferOriginAnalysis &analysis)
139 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
147 LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
148 PatternRewriter &rewriter)
const {
154 bool atLeastOneMustAlias =
false;
155 for (Value retained : deallocOp.getRetained()) {
156 std::optional<bool> analysisResult =
157 analysis.isSameAllocation(retained, memref);
158 if (!analysisResult.has_value())
160 if (analysisResult ==
true)
161 atLeastOneMustAlias =
true;
163 if (!atLeastOneMustAlias)
169 for (
auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
170 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
171 std::optional<bool> analysisResult =
172 analysis.isSameAllocation(retained, memref);
173 if (analysisResult ==
true) {
174 auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
175 updatedCondition, cond);
184 LogicalResult matchAndRewrite(DeallocOp deallocOp,
185 PatternRewriter &rewriter)
const override {
189 deallocOp.getRetained().end());
190 if (retained.size() != deallocOp.getRetained().size())
193 SmallVector<Value> newMemrefs, newConditions;
194 for (
auto [memref, cond] :
195 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
197 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
202 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
206 newMemrefs.push_back(memref);
207 newConditions.push_back(cond);
217 BufferOriginAnalysis &analysis;
236struct RemoveRetainedMemrefsGuaranteedToNotAlias
238 RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
239 BufferOriginAnalysis &analysis)
240 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
242 LogicalResult matchAndRewrite(DeallocOp deallocOp,
243 PatternRewriter &rewriter)
const override {
244 SmallVector<Value> newRetainedMemrefs, replacements;
246 for (
auto retainedMemref : deallocOp.getRetained()) {
249 newRetainedMemrefs.push_back(retainedMemref);
250 replacements.push_back({});
254 replacements.push_back(arith::ConstantOp::create(
255 rewriter, deallocOp.getLoc(), rewriter.
getBoolAttr(
false)));
258 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
262 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
263 deallocOp.getConditions(), newRetainedMemrefs);
265 for (
auto &repl : replacements) {
267 repl = newDeallocOp.getUpdatedConditions()[i++];
270 rewriter.
replaceOp(deallocOp, replacements);
275 BufferOriginAnalysis &analysis;
305struct SplitDeallocWhenNotAliasingAnyOther
307 SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
308 BufferOriginAnalysis &analysis)
309 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
311 LogicalResult matchAndRewrite(DeallocOp deallocOp,
312 PatternRewriter &rewriter)
const override {
313 Location loc = deallocOp.getLoc();
314 if (deallocOp.getMemrefs().size() <= 1)
317 SmallVector<Value> remainingMemrefs, remainingConditions;
318 SmallVector<SmallVector<Value>> updatedConditions;
319 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
320 Value memref = deallocOp.getMemrefs()[i];
321 Value cond = deallocOp.getConditions()[i];
322 SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
323 otherMemrefs.erase(otherMemrefs.begin() + i);
327 remainingMemrefs.push_back(memref);
328 remainingConditions.push_back(cond);
333 auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,
334 deallocOp.getRetained());
335 updatedConditions.push_back(
336 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions())));
340 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
345 DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,
346 deallocOp.getRetained());
349 SmallVector<Value> replacements =
350 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions()));
351 for (
auto additionalConditions : updatedConditions) {
352 assert(replacements.size() == additionalConditions.size() &&
353 "expected same number of updated conditions");
354 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
355 replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],
356 additionalConditions[i]);
359 rewriter.
replaceOp(deallocOp, replacements);
364 BufferOriginAnalysis &analysis;
390struct RetainedMemrefAliasingAlwaysDeallocatedMemref
392 RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
393 BufferOriginAnalysis &analysis)
394 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
396 LogicalResult matchAndRewrite(DeallocOp deallocOp,
397 PatternRewriter &rewriter)
const override {
398 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
399 SmallVector<Value> newMemrefs, newConditions;
400 for (
auto [memref, cond] :
401 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
402 bool canDropMemref =
false;
403 for (
auto [i, retained, res] : llvm::enumerate(
404 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
408 std::optional<bool> analysisResult =
409 analysis.isSameAllocation(retained, memref);
410 if (analysisResult ==
true) {
412 aliasesWithConstTrueMemref[i] =
true;
413 canDropMemref =
true;
424 std::optional<bool> extractAnalysisResult =
425 analysis.isSameAllocation(retained, extractOp.getOperand());
426 if (extractAnalysisResult ==
true) {
428 aliasesWithConstTrueMemref[i] =
true;
429 canDropMemref =
true;
433 if (!canDropMemref) {
434 newMemrefs.push_back(memref);
435 newConditions.push_back(cond);
438 if (!aliasesWithConstTrueMemref.all())
446 BufferOriginAnalysis &analysis;
460struct BufferDeallocationSimplificationPass
461 :
public bufferization::impl::BufferDeallocationSimplificationPassBase<
462 BufferDeallocationSimplificationPass> {
463 void runOnOperation()
override {
464 BufferOriginAnalysis
analysis(getOperation());
466 patterns.add<RemoveDeallocMemrefsContainedInRetained,
467 RemoveRetainedMemrefsGuaranteedToNotAlias,
468 SplitDeallocWhenNotAliasingAnyOther,
469 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&
getContext(),
477 getOperation(), std::move(
patterns),
478 GreedyRewriteConfig().setRegionSimplificationLevel(
479 GreedySimplifyRegionLevel::Normal))))
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)
Checks if memref may potentially alias a MemRef in otherList.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.
BoolAttr getBoolAttr(bool value)
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...