24 namespace bufferization {
25 #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
40 while (
auto viewLikeOp = value.
getDefiningOp<ViewLikeOpInterface>())
41 value = viewLikeOp.getViewSource();
49 if (deallocOp.getMemrefs() == memrefs &&
50 deallocOp.getConditions() == conditions)
54 deallocOp.getMemrefsMutable().assign(memrefs);
55 deallocOp.getConditionsMutable().assign(conditions);
71 if (hasEffect<MemoryEffects::Allocate>(op, v1))
72 if (
auto bbArg = dyn_cast<BlockArgument>(v2))
73 if (bbArg.getOwner()->findAncestorOpInBlock(*op))
77 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
85 for (
auto other : otherList) {
88 std::optional<bool> analysisResult =
90 if (!analysisResult.has_value() || analysisResult ==
true)
131 struct RemoveDeallocMemrefsContainedInRetained
133 RemoveDeallocMemrefsContainedInRetained(
MLIRContext *context,
143 LogicalResult handleOneMemref(DeallocOp deallocOp,
Value memref,
Value cond,
150 bool atLeastOneMustAlias =
false;
151 for (
Value retained : deallocOp.getRetained()) {
152 std::optional<bool> analysisResult =
153 analysis.isSameAllocation(retained, memref);
154 if (!analysisResult.has_value())
156 if (analysisResult ==
true)
157 atLeastOneMustAlias =
true;
159 if (!atLeastOneMustAlias)
166 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
167 std::optional<bool> analysisResult =
168 analysis.isSameAllocation(retained, memref);
169 if (analysisResult ==
true) {
170 auto disjunction = rewriter.
create<arith::OrIOp>(
171 deallocOp.getLoc(), updatedCondition, cond);
180 LogicalResult matchAndRewrite(DeallocOp deallocOp,
185 deallocOp.getRetained().end());
186 if (retained.size() != deallocOp.getRetained().size())
190 for (
auto [memref, cond] :
191 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
193 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
198 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
202 newMemrefs.push_back(memref);
203 newConditions.push_back(cond);
232 struct RemoveRetainedMemrefsGuaranteedToNotAlias
234 RemoveRetainedMemrefsGuaranteedToNotAlias(
MLIRContext *context,
238 LogicalResult matchAndRewrite(DeallocOp deallocOp,
242 for (
auto retainedMemref : deallocOp.getRetained()) {
245 newRetainedMemrefs.push_back(retainedMemref);
246 replacements.push_back({});
250 replacements.push_back(rewriter.
create<arith::ConstantOp>(
254 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
257 auto newDeallocOp = rewriter.
create<DeallocOp>(
258 deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
261 for (
auto &repl : replacements) {
263 repl = newDeallocOp.getUpdatedConditions()[i++];
266 rewriter.
replaceOp(deallocOp, replacements);
301 struct SplitDeallocWhenNotAliasingAnyOther
303 SplitDeallocWhenNotAliasingAnyOther(
MLIRContext *context,
307 LogicalResult matchAndRewrite(DeallocOp deallocOp,
310 if (deallocOp.getMemrefs().size() <= 1)
315 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
316 Value memref = deallocOp.getMemrefs()[i];
317 Value cond = deallocOp.getConditions()[i];
319 otherMemrefs.erase(otherMemrefs.begin() + i);
323 remainingMemrefs.push_back(memref);
324 remainingConditions.push_back(cond);
329 auto newDeallocOp = rewriter.
create<DeallocOp>(loc, memref, cond,
330 deallocOp.getRetained());
331 updatedConditions.push_back(
332 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions())));
336 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
340 auto newDeallocOp = rewriter.
create<DeallocOp>(
341 loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
345 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions()));
346 for (
auto additionalConditions : updatedConditions) {
347 assert(replacements.size() == additionalConditions.size() &&
348 "expected same number of updated conditions");
349 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
350 replacements[i] = rewriter.
create<arith::OrIOp>(
351 loc, replacements[i], additionalConditions[i]);
354 rewriter.
replaceOp(deallocOp, replacements);
385 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
387 RetainedMemrefAliasingAlwaysDeallocatedMemref(
MLIRContext *context,
391 LogicalResult matchAndRewrite(DeallocOp deallocOp,
393 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
395 for (
auto [memref, cond] :
396 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
397 bool canDropMemref =
false;
399 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
403 std::optional<bool> analysisResult =
404 analysis.isSameAllocation(retained, memref);
405 if (analysisResult ==
true) {
407 aliasesWithConstTrueMemref[i] =
true;
408 canDropMemref =
true;
415 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
419 std::optional<bool> extractAnalysisResult =
420 analysis.isSameAllocation(retained, extractOp.getOperand());
421 if (extractAnalysisResult ==
true) {
423 aliasesWithConstTrueMemref[i] =
true;
424 canDropMemref =
true;
428 if (!canDropMemref) {
429 newMemrefs.push_back(memref);
430 newConditions.push_back(cond);
433 if (!aliasesWithConstTrueMemref.all())
455 struct BufferDeallocationSimplificationPass
456 :
public bufferization::impl::BufferDeallocationSimplificationBase<
457 BufferDeallocationSimplificationPass> {
458 void runOnOperation()
override {
461 patterns.add<RemoveDeallocMemrefsContainedInRetained,
462 RemoveRetainedMemrefsGuaranteedToNotAlias,
463 SplitDeallocWhenNotAliasingAnyOther,
464 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&
getContext(),
481 std::unique_ptr<Pass>
483 return std::make_unique<BufferDeallocationSimplificationPass>();
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, ValueRange otherList, Value memref)
Checks if memref may potentially alias a MemRef in otherList.
static Value getViewBase(Value value)
Given a memref value, return the "base" value by skipping over all ViewLikeOpInterface ops (if any) i...
static bool distinctAllocAndBlockArgument(Value v1, Value v2)
Return "true" if the given values are guaranteed to be different (and non-aliasing) allocations based...
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static MLIRContext * getContext(OpFoldResult val)
An is-same-buffer analysis that checks if two SSA values belong to the same buffer allocation or not.
std::optional< bool > isSameAllocation(Value v1, Value v2)
Return "true" if v1 and v2 originate from the same buffer allocation.
BoolAttr getBoolAttr(bool value)
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
std::unique_ptr< Pass > createBufferDeallocationSimplificationPass()
Creates a pass that optimizes bufferization.dealloc operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ Normal
Run the normal simplification (e.g. dead args elimination).
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...