89 for (
auto other : otherList) {
92 std::optional<bool> analysisResult =
93 analysis.isSameAllocation(other,
memref);
94 if (!analysisResult.has_value() || analysisResult ==
true)
// Struct with a `matchAndRewrite(DeallocOp, PatternRewriter &)` override. For
// each (memref, condition) pair of the op it asks BufferOriginAnalysis whether
// the memref is the same allocation as one of the op's retained values.
// NOTE(review): this listing is lossy -- the embedded line numbers jump, so the
// base-class clause, several branch bodies and the closing braces are not
// visible. The comments below describe only what is shown.
135struct RemoveDeallocMemrefsContainedInRetained
// Constructor; only its leading MLIRContext parameter is visible here.
137 RemoveDeallocMemrefsContainedInRetained(
MLIRContext *context,
// Checks one deallocated `memref` (with its condition `cond`) against every
// retained value of `deallocOp`. The return statements fall on omitted lines.
147 LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
148 PatternRewriter &rewriter)
const {
// First loop: note whether isSameAllocation reports `true` for at least one
// retained value.
154 bool atLeastOneMustAlias =
false;
155 for (Value retained : deallocOp.getRetained()) {
156 std::optional<bool> analysisResult =
157 analysis.isSameAllocation(retained, memref);
// The analysis produced no answer; the action taken is on an omitted line
// (presumably gives up on this memref -- TODO confirm).
158 if (!analysisResult.has_value())
160 if (analysisResult ==
true)
161 atLeastOneMustAlias =
true;
// No retained value was reported as the same allocation; the branch body is
// not visible in this listing.
163 if (!atLeastOneMustAlias)
// Second loop: walk the retained values together with their index so the
// matching updated-condition result of the op can be looked up.
169 for (
auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
170 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
171 std::optional<bool> analysisResult =
172 analysis.isSameAllocation(retained, memref);
// For a retained value reported as the same allocation, create an arith OrIOp
// of the updated condition and `cond`. How `disjunction` is used afterwards is
// on omitted lines.
173 if (analysisResult ==
true) {
174 auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(),
175 updatedCondition, cond);
// Pattern entry point. Visible steps: a size comparison between a `retained`
// container and the op's retained range, then a pass over all
// (memref, condition) pairs that fills `newMemrefs` / `newConditions`.
184 LogicalResult matchAndRewrite(DeallocOp deallocOp,
185 PatternRewriter &rewriter)
const override {
// Tail of the construction of `retained` from the op's retained range; its
// declaration is on an omitted line (presumably a set used to detect
// duplicates -- TODO confirm).
189 deallocOp.getRetained().end());
190 if (retained.size() != deallocOp.getRetained().size())
193 SmallVector<Value> newMemrefs, newConditions;
194 for (
auto [memref, cond] :
195 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
// Try the memref itself first.
197 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
// Then try the operand of `extractOp`, which is defined on an omitted line.
202 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
// The pair is appended to the new lists here; presumably only when neither
// attempt above succeeded (the intervening control flow is omitted).
206 newMemrefs.push_back(memref);
207 newConditions.push_back(cond);
// Analysis queried through isSameAllocation. Stored as a reference, so the
// referenced object has to outlive this struct.
217 BufferOriginAnalysis &analysis;
// Rewrite pattern on DeallocOp (its constructor initialises
// OpRewritePattern<DeallocOp>) that rebuilds the op with a filtered list of
// retained memrefs and replaces the old op's results.
// NOTE(review): this listing is lossy -- the embedded line numbers jump, so the
// filtering condition, some branch bodies and the closing braces are not
// visible. The comments below describe only what is shown.
236struct RemoveRetainedMemrefsGuaranteedToNotAlias
// Stores the analysis reference and forwards `context` to the base pattern.
238 RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
239 BufferOriginAnalysis &analysis)
240 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
// Pattern entry point; returns a LogicalResult (return statements are on
// omitted lines).
242 LogicalResult matchAndRewrite(DeallocOp deallocOp,
243 PatternRewriter &rewriter)
const override {
// `newRetainedMemrefs` is the retained list for the rebuilt op;
// `replacements` collects the values later handed to replaceOp.
244 SmallVector<Value> newRetainedMemrefs, replacements;
246 for (
auto retainedMemref : deallocOp.getRetained()) {
// Keep this retained memref and push a default-constructed (null) Value as a
// placeholder. The condition selecting this path is on omitted lines.
249 newRetainedMemrefs.push_back(retainedMemref);
250 replacements.push_back({});
// Push a constant `false` as the replacement instead; presumably the
// alternative path for a retained memref that is dropped -- TODO confirm.
254 replacements.push_back(arith::ConstantOp::create(
255 rewriter, deallocOp.getLoc(), rewriter.
getBoolAttr(
false)));
// Every retained memref was kept; the branch body is not visible (presumably
// reports that nothing changed -- TODO confirm).
258 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
// Create a dealloc with the original memrefs and conditions but the filtered
// retained list. The variable receiving it (`newDeallocOp`, used below) is
// declared on an omitted line.
262 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
263 deallocOp.getConditions(), newRetainedMemrefs);
// Assign updated conditions of the new op into `replacements`, consuming them
// in order via `i`. The declaration of `i` and any guard deciding which
// entries are overwritten are on omitted lines.
265 for (
auto &repl : replacements) {
267 repl = newDeallocOp.getUpdatedConditions()[i++];
// Replace the original op's results with the assembled values.
270 rewriter.
replaceOp(deallocOp, replacements);
// Analysis stored by reference; the referenced object has to outlive this
// pattern. Its use inside matchAndRewrite falls on omitted lines.
275 BufferOriginAnalysis &analysis;
// Rewrite pattern on DeallocOp (its constructor initialises
// OpRewritePattern<DeallocOp>) that creates separate single-memref dealloc ops
// for some memrefs, one dealloc op for the remaining ones, and combines their
// updated conditions with arith OrIOp.
// NOTE(review): this listing is lossy -- the embedded line numbers jump, so the
// test that decides which memrefs are split off, some branch bodies and the
// closing braces are not visible. The comments below describe only what is
// shown.
305struct SplitDeallocWhenNotAliasingAnyOther
// Stores the analysis reference and forwards `context` to the base pattern.
307 SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
308 BufferOriginAnalysis &analysis)
309 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
// Pattern entry point; returns a LogicalResult (return statements are on
// omitted lines).
311 LogicalResult matchAndRewrite(DeallocOp deallocOp,
312 PatternRewriter &rewriter)
const override {
313 Location loc = deallocOp.getLoc();
// Ops with at most one memref are handled by a branch whose body is omitted
// (presumably nothing to split -- TODO confirm).
314 if (deallocOp.getMemrefs().size() <= 1)
// Memrefs/conditions that stay together in one op, and the updated-condition
// lists of each split-off op.
317 SmallVector<Value> remainingMemrefs, remainingConditions;
318 SmallVector<SmallVector<Value>> updatedConditions;
319 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
320 Value memref = deallocOp.getMemrefs()[i];
321 Value cond = deallocOp.getConditions()[i];
// Copy of all memrefs with the current one removed, i.e. "every other memref".
322 SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
323 otherMemrefs.erase(otherMemrefs.begin() + i);
// Keep this memref in the remaining group. The condition guarding this path
// (on omitted lines) presumably involves `otherMemrefs` -- TODO confirm.
327 remainingMemrefs.push_back(memref);
328 remainingConditions.push_back(cond);
// Otherwise create a dealloc for just this memref/condition with the same
// retained list, and remember its updated conditions.
333 auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond,
334 deallocOp.getRetained());
335 updatedConditions.push_back(
336 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions())));
// No memref was split off; the branch body is not visible (presumably reports
// that nothing changed -- TODO confirm).
340 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
// Create the dealloc for the remaining group. The variable receiving it
// (`newDeallocOp`, used below) is declared on an omitted line.
345 DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions,
346 deallocOp.getRetained());
// Start from the remaining op's updated conditions and OR in, position by
// position, the updated conditions of every split-off op.
349 SmallVector<Value> replacements =
350 llvm::to_vector(
ValueRange(newDeallocOp.getUpdatedConditions()));
351 for (
auto additionalConditions : updatedConditions) {
// All ops were created with the same retained list, so the per-op lists are
// expected to have equal length.
352 assert(replacements.size() == additionalConditions.size() &&
353 "expected same number of updated conditions");
354 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
355 replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i],
356 additionalConditions[i]);
// Replace the original op's results with the combined conditions.
359 rewriter.
replaceOp(deallocOp, replacements);
// Analysis stored by reference; the referenced object has to outlive this
// pattern. Its use inside matchAndRewrite falls on omitted lines.
364 BufferOriginAnalysis &analysis;
// Rewrite pattern on DeallocOp (its constructor initialises
// OpRewritePattern<DeallocOp>) that tracks, per retained value, whether it was
// reported as the same allocation as a deallocated memref, and collects the
// memrefs that cannot be dropped.
// NOTE(review): this listing is lossy -- the embedded line numbers jump, so
// additional guards inside the loops, the final rewrite and the closing braces
// are not visible. The comments below describe only what is shown.
390struct RetainedMemrefAliasingAlwaysDeallocatedMemref
// Stores the analysis reference and forwards `context` to the base pattern.
392 RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
393 BufferOriginAnalysis &analysis)
394 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
// Pattern entry point; returns a LogicalResult (return statements are on
// omitted lines).
396 LogicalResult matchAndRewrite(DeallocOp deallocOp,
397 PatternRewriter &rewriter)
const override {
// One bit per retained value, set when that value is matched below.
398 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
399 SmallVector<Value> newMemrefs, newConditions;
400 for (
auto [memref, cond] :
401 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
402 bool canDropMemref =
false;
// Visit each retained value together with its index and the corresponding
// updated-condition result `res` (the use of `res` is on omitted lines).
403 for (
auto [i, retained, res] : llvm::enumerate(
404 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
// Direct check: is `retained` the same allocation as `memref`? Any earlier
// filtering of this pair happens on omitted lines.
408 std::optional<bool> analysisResult =
409 analysis.isSameAllocation(retained, memref);
410 if (analysisResult ==
true) {
412 aliasesWithConstTrueMemref[i] =
true;
413 canDropMemref =
true;
// Second check against the operand of `extractOp`, which is defined on an
// omitted line.
424 std::optional<bool> extractAnalysisResult =
425 analysis.isSameAllocation(retained, extractOp.getOperand());
426 if (extractAnalysisResult ==
true) {
428 aliasesWithConstTrueMemref[i] =
true;
429 canDropMemref =
true;
// Memrefs not matched with any retained value are carried over unchanged.
433 if (!canDropMemref) {
434 newMemrefs.push_back(memref);
435 newConditions.push_back(cond);
// Not every retained value was matched; the branch body is not visible
// (presumably the pattern does not apply -- TODO confirm).
438 if (!aliasesWithConstTrueMemref.all())
// Analysis queried through isSameAllocation. Stored as a reference, so the
// referenced object has to outlive this pattern.
446 BufferOriginAnalysis &analysis;
// Pass that builds a BufferOriginAnalysis for the current operation and
// registers the four DeallocOp patterns listed below.
// NOTE(review): this listing is lossy -- the embedded line numbers jump and
// the text is cut mid-expression at the end, so the base class, the
// declaration of `patterns`, the driver call name and the closing braces are
// not visible. The comments below describe only what is shown.
460struct BufferDeallocationSimplificationPass
462 BufferDeallocationSimplificationPass> {
463 void runOnOperation()
override {
// Analysis constructed from the operation the pass runs on.
464 BufferOriginAnalysis
analysis(getOperation());
// Register the patterns, passing the context first; the remaining
// constructor arguments are on an omitted line (each pattern constructor
// visible in this file also takes a BufferOriginAnalysis reference).
466 patterns.add<RemoveDeallocMemrefsContainedInRetained,
467 RemoveRetainedMemrefsGuaranteedToNotAlias,
468 SplitDeallocWhenNotAliasingAnyOther,
469 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&
getContext(),
// Arguments of a call whose callee is on an omitted line (presumably the
// greedy pattern driver -- TODO confirm): the operation, the pattern set, and
// a GreedyRewriteConfig with region simplification level `Normal`.
477 getOperation(), std::move(
patterns),
478 GreedyRewriteConfig().setRegionSimplificationLevel(
479 GreedySimplifyRegionLevel::Normal))))