24 namespace bufferization {
25 #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATIONPASS
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
40 while (
auto viewLikeOp = value.
getDefiningOp<ViewLikeOpInterface>()) {
41 if (value != viewLikeOp.getViewDest()) {
44 value = viewLikeOp.getViewSource();
53 if (deallocOp.getMemrefs() == memrefs &&
54 deallocOp.getConditions() == conditions)
58 deallocOp.getMemrefsMutable().assign(memrefs);
59 deallocOp.getConditionsMutable().assign(conditions);
75 if (hasEffect<MemoryEffects::Allocate>(op, v1))
76 if (
auto bbArg = dyn_cast<BlockArgument>(v2))
77 if (bbArg.getOwner()->findAncestorOpInBlock(*op))
81 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
89 for (
auto other : otherList) {
92 std::optional<bool> analysisResult =
93 analysis.isSameAllocation(other, memref);
94 if (!analysisResult.has_value() || analysisResult ==
true)
135 struct RemoveDeallocMemrefsContainedInRetained
137 RemoveDeallocMemrefsContainedInRetained(
MLIRContext *context,
147 LogicalResult handleOneMemref(DeallocOp deallocOp,
Value memref,
Value cond,
154 bool atLeastOneMustAlias =
false;
155 for (
Value retained : deallocOp.getRetained()) {
156 std::optional<bool> analysisResult =
157 analysis.isSameAllocation(retained, memref);
158 if (!analysisResult.has_value())
160 if (analysisResult ==
true)
161 atLeastOneMustAlias =
true;
163 if (!atLeastOneMustAlias)
170 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
171 std::optional<bool> analysisResult =
172 analysis.isSameAllocation(retained, memref);
173 if (analysisResult ==
true) {
174 auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
175 updatedCondition, cond);
184 LogicalResult matchAndRewrite(DeallocOp deallocOp,
189 deallocOp.getRetained().end());
190 if (retained.size() != deallocOp.getRetained().size())
194 for (
auto [memref, cond] :
195 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
197 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
202 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
206 newMemrefs.push_back(memref);
207 newConditions.push_back(cond);
236 struct RemoveRetainedMemrefsGuaranteedToNotAlias
238 RemoveRetainedMemrefsGuaranteedToNotAlias(
MLIRContext *context,
242 LogicalResult matchAndRewrite(DeallocOp deallocOp,
246 for (
auto retainedMemref : deallocOp.getRetained()) {
249 newRetainedMemrefs.push_back(retainedMemref);
250 replacements.push_back({});
254 replacements.push_back(arith::ConstantOp::create(
255 rewriter, deallocOp.getLoc(), rewriter.
getBoolAttr(
false)));
258 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
262 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
263 deallocOp.getConditions(), newRetainedMemrefs);
265 for (
auto &repl : replacements) {
267 repl = newDeallocOp.getUpdatedConditions()[i++];
270 rewriter.
replaceOp(deallocOp, replacements);
305 struct SplitDeallocWhenNotAliasingAnyOther
307 SplitDeallocWhenNotAliasingAnyOther(
MLIRContext *context,
311 LogicalResult matchAndRewrite(DeallocOp deallocOp,
314 if (deallocOp.getMemrefs().size() <= 1)
319 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
320 Value memref = deallocOp.getMemrefs()[i];
321 Value cond = deallocOp.getConditions()[i];
323 otherMemrefs.erase(otherMemrefs.begin() + i);
327 remainingMemrefs.push_back(memref);
328 remainingConditions.push_back(cond);
333 auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,
334 deallocOp.getRetained());
335 updatedConditions.push_back(
336 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions())));
340 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
345 DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,
346 deallocOp.getRetained());
350 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions()));
351 for (
auto additionalConditions : updatedConditions) {
352 assert(replacements.size() == additionalConditions.size() &&
353 "expected same number of updated conditions");
354 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
355 replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],
356 additionalConditions[i]);
359 rewriter.
replaceOp(deallocOp, replacements);
390 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
392 RetainedMemrefAliasingAlwaysDeallocatedMemref(
MLIRContext *context,
396 LogicalResult matchAndRewrite(DeallocOp deallocOp,
398 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
400 for (
auto [memref, cond] :
401 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
402 bool canDropMemref =
false;
404 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
408 std::optional<bool> analysisResult =
409 analysis.isSameAllocation(retained, memref);
410 if (analysisResult ==
true) {
412 aliasesWithConstTrueMemref[i] =
true;
413 canDropMemref =
true;
424 std::optional<bool> extractAnalysisResult =
425 analysis.isSameAllocation(retained, extractOp.getOperand());
426 if (extractAnalysisResult ==
true) {
428 aliasesWithConstTrueMemref[i] =
true;
429 canDropMemref =
true;
433 if (!canDropMemref) {
434 newMemrefs.push_back(memref);
435 newConditions.push_back(cond);
438 if (!aliasesWithConstTrueMemref.all())
460 struct BufferDeallocationSimplificationPass
461 :
public bufferization::impl::BufferDeallocationSimplificationPassBase<
462 BufferDeallocationSimplificationPass> {
463 void runOnOperation()
override {
466 patterns.add<RemoveDeallocMemrefsContainedInRetained,
467 RemoveRetainedMemrefsGuaranteedToNotAlias,
468 SplitDeallocWhenNotAliasingAnyOther,
469 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&
getContext(),
477 getOperation(), std::move(
patterns),
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)
Checks if memref may potentially alias a MemRef in otherList.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static MLIRContext * getContext(OpFoldResult val)
An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.
BoolAttr getBoolAttr(bool value)
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Normal
Run the normal simplification (e.g. dead args elimination).
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...