24 namespace bufferization {
25 #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
41 if (deallocOp.getMemrefs() == memrefs &&
42 deallocOp.getConditions() == conditions)
46 deallocOp.getMemrefsMutable().assign(memrefs);
47 deallocOp.getConditionsMutable().assign(conditions);
55 while (
auto viewLikeOp = value.
getDefiningOp<ViewLikeOpInterface>())
56 value = viewLikeOp.getViewSource();
71 if (hasEffect<MemoryEffects::Allocate>(op, v1))
72 if (
auto bbArg = dyn_cast<BlockArgument>(v2))
73 if (bbArg.getOwner()->findAncestorOpInBlock(*op))
77 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
88 bool allowSelfAlias) {
89 for (
auto other : otherList) {
90 if (allowSelfAlias && other == memref)
94 if (!analysis.
alias(other, memref).
isNo())
135 struct RemoveDeallocMemrefsContainedInRetained
137 RemoveDeallocMemrefsContainedInRetained(
MLIRContext *context,
154 bool atLeastOneMustAlias =
false;
155 for (
Value retained : deallocOp.getRetained()) {
156 AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
157 if (analysisResult.
isMay())
160 atLeastOneMustAlias =
true;
162 if (!atLeastOneMustAlias)
169 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
170 AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
172 auto disjunction = rewriter.
create<arith::OrIOp>(
173 deallocOp.getLoc(), updatedCondition, cond);
187 deallocOp.getRetained().end());
188 if (retained.size() != deallocOp.getRetained().size())
192 for (
auto [memref, cond] :
193 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
195 if (
succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
200 if (
succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
204 newMemrefs.push_back(memref);
205 newConditions.push_back(cond);
234 struct RemoveRetainedMemrefsGuaranteedToNotAlias
236 RemoveRetainedMemrefsGuaranteedToNotAlias(
MLIRContext *context,
244 auto getOrCreateFalse = [&]() ->
Value {
246 falseValue = rewriter.
create<arith::ConstantOp>(
251 for (
auto retainedMemref : deallocOp.getRetained()) {
253 retainedMemref,
false)) {
254 newRetainedMemrefs.push_back(retainedMemref);
255 replacements.push_back({});
259 replacements.push_back(getOrCreateFalse());
262 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
265 auto newDeallocOp = rewriter.
create<DeallocOp>(
266 deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
269 for (
auto &repl : replacements) {
271 repl = newDeallocOp.getUpdatedConditions()[i++];
274 rewriter.
replaceOp(deallocOp, replacements);
309 struct SplitDeallocWhenNotAliasingAnyOther
311 SplitDeallocWhenNotAliasingAnyOther(
MLIRContext *context,
317 if (deallocOp.getMemrefs().size() <= 1)
322 replacements = deallocOp.getUpdatedConditions();
323 for (
auto [memref, cond] :
324 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
327 newMemrefs.push_back(memref);
328 newConditions.push_back(cond);
332 auto newDeallocOp = rewriter.
create<DeallocOp>(
333 deallocOp.getLoc(), memref, cond, deallocOp.getRetained());
335 llvm::zip(replacements, newDeallocOp.getUpdatedConditions()),
336 [&](
auto replAndNew) ->
Value {
337 auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
338 std::get<0>(replAndNew),
339 std::get<1>(replAndNew));
340 exceptedUsers.insert(orOp);
341 return orOp.getResult();
345 if (newMemrefs.size() == deallocOp.getMemrefs().size())
350 return !exceptedUsers.contains(
384 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
386 RetainedMemrefAliasingAlwaysDeallocatedMemref(
MLIRContext *context,
392 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
394 for (
auto [memref, cond] :
395 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
396 bool canDropMemref =
false;
398 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
402 AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
405 aliasesWithConstTrueMemref[i] =
true;
406 canDropMemref =
true;
418 aliasAnalysis.alias(retained, extractOp.getOperand());
419 if (extractAnalysisResult.
isMust() ||
422 aliasesWithConstTrueMemref[i] =
true;
423 canDropMemref =
true;
427 if (!canDropMemref) {
428 newMemrefs.push_back(memref);
429 newConditions.push_back(cond);
432 if (!aliasesWithConstTrueMemref.all())
454 struct BufferDeallocationSimplificationPass
455 :
public bufferization::impl::BufferDeallocationSimplificationBase<
456 BufferDeallocationSimplificationPass> {
457 void runOnOperation()
override {
460 patterns.add<RemoveDeallocMemrefsContainedInRetained,
461 RemoveRetainedMemrefsGuaranteedToNotAlias,
462 SplitDeallocWhenNotAliasingAnyOther,
463 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&
getContext(),
475 std::unique_ptr<Pass>
477 return std::make_unique<BufferDeallocationSimplificationPass>();
static bool potentiallyAliasesMemref(AliasAnalysis &analysis, ValueRange otherList, Value memref, bool allowSelfAlias)
Checks if memref may or must alias a MemRef in otherList.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static MLIRContext * getContext(OpFoldResult val)
This class represents the main alias analysis interface in MLIR.
AliasResult alias(Value lhs, Value rhs)
Given two values, return their aliasing behavior.
The possible results of an alias query.
bool isPartial() const
Returns if this result is a partial alias.
bool isMay() const
Returns if this result is a may alias.
bool isMust() const
Returns if this result is a must alias.
bool isNo() const
Returns if this result indicates no possibility of aliasing.
BoolAttr getBoolAttr(bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
std::unique_ptr< Pass > createBufferDeallocationSimplificationPass()
Creates a pass that optimizes bufferization.dealloc operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...