26 namespace bufferization {
27 #define GEN_PASS_DEF_LOWERDEALLOCATIONS
28 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
43 class DeallocOpConversion
62 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
64 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
65 assert(adaptor.getRetained().empty() &&
"expected no retained memrefs");
69 builder.
create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
70 builder.
create<scf::YieldOp>(loc);
106 bufferization::DeallocOp op, OpAdaptor adaptor,
108 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
113 Value memrefAsIdx = rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(
114 op->
getLoc(), adaptor.getMemrefs()[0]);
115 for (
Value retained : adaptor.getRetained()) {
116 Value retainedAsIdx =
117 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op->
getLoc(),
119 Value doesNotAlias = rewriter.
create<arith::CmpIOp>(
120 op->
getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
121 doesNotAliasList.push_back(doesNotAlias);
125 Value prev = doesNotAliasList.front();
126 for (
Value doesNotAlias :
ArrayRef(doesNotAliasList).drop_front())
127 prev = rewriter.
create<arith::AndIOp>(op->
getLoc(), prev, doesNotAlias);
131 Value shouldDealloc = rewriter.
create<arith::AndIOp>(
132 op->
getLoc(), prev, adaptor.getConditions()[0]);
134 rewriter.
create<scf::IfOp>(
136 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
137 builder.create<scf::YieldOp>(loc);
145 Value trueVal = rewriter.
create<arith::ConstantOp>(
147 for (
Value doesNotAlias : doesNotAliasList) {
149 rewriter.
create<arith::XOrIOp>(op->
getLoc(), doesNotAlias, trueVal);
151 adaptor.getConditions()[0]);
152 replacements.push_back(result);
226 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
235 Value toDeallocMemref = rewriter.
create<memref::AllocOp>(
238 Value conditionMemref = rewriter.
create<memref::AllocOp>(
241 Value toRetainMemref = rewriter.
create<memref::AllocOp>(
245 auto getConstValue = [&](uint64_t value) ->
Value {
254 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.
getLoc(),
256 rewriter.
create<memref::StoreOp>(op.
getLoc(), memrefAsIdx,
257 toDeallocMemref, getConstValue(i));
261 rewriter.
create<memref::StoreOp>(op.
getLoc(), cond, conditionMemref,
266 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.
getLoc(),
268 rewriter.
create<memref::StoreOp>(op.
getLoc(), memrefAsIdx, toRetainMemref,
275 Value castedDeallocMemref = rewriter.
create<memref::CastOp>(
279 Value castedCondsMemref = rewriter.
create<memref::CastOp>(
283 Value castedRetainMemref = rewriter.
create<memref::CastOp>(
288 Value deallocCondsMemref = rewriter.
create<memref::AllocOp>(
291 Value retainCondsMemref = rewriter.
create<memref::AllocOp>(
295 Value castedDeallocCondsMemref = rewriter.
create<memref::CastOp>(
299 Value castedRetainCondsMemref = rewriter.
create<memref::CastOp>(
304 rewriter.
create<func::CallOp>(
305 op.
getLoc(), deallocHelperFunc,
307 castedCondsMemref, castedDeallocCondsMemref,
308 castedRetainCondsMemref});
310 for (
unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
311 Value idxValue = getConstValue(i);
312 Value shouldDealloc = rewriter.
create<memref::LoadOp>(
313 op.
getLoc(), deallocCondsMemref, idxValue);
314 rewriter.
create<scf::IfOp>(
316 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
317 builder.create<scf::YieldOp>(loc);
322 for (
unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
323 Value idxValue = getConstValue(i);
325 op.
getLoc(), retainCondsMemref, idxValue);
326 replacements.push_back(ownership);
331 rewriter.
create<memref::DeallocOp>(op.
getLoc(), toDeallocMemref);
332 rewriter.
create<memref::DeallocOp>(op.
getLoc(), toRetainMemref);
333 rewriter.
create<memref::DeallocOp>(op.
getLoc(), conditionMemref);
334 rewriter.
create<memref::DeallocOp>(op.
getLoc(), deallocCondsMemref);
335 rewriter.
create<memref::DeallocOp>(op.
getLoc(), retainCondsMemref);
342 DeallocOpConversion(
MLIRContext *context, func::FuncOp deallocHelperFunc)
344 deallocHelperFunc(deallocHelperFunc) {}
347 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
350 if (adaptor.getMemrefs().empty()) {
351 Value falseVal = rewriter.
create<arith::ConstantOp>(
358 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
359 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
361 if (adaptor.getMemrefs().size() == 1)
362 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
364 if (!deallocHelperFunc)
366 "library function required for generic lowering, but cannot be "
367 "automatically inserted when operating on functions");
369 return rewriteGeneralCase(op, adaptor, rewriter);
373 func::FuncOp deallocHelperFunc;
378 struct LowerDeallocationsPass
379 :
public bufferization::impl::LowerDeallocationsBase<
380 LowerDeallocationsPass> {
381 void runOnOperation()
override {
382 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
384 "root operation must be a builtin.module or a function");
389 func::FuncOp helperFuncOp;
390 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
396 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
397 if (deallocOp.getMemrefs().size() > 1) {
398 helperFuncOp = bufferization::buildDeallocationLibraryFunction(
399 builder, getOperation()->getLoc(), symbolTable);
400 return WalkResult::interrupt();
411 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
412 scf::SCFDialect, func::FuncDialect>();
413 target.addIllegalOp<bufferization::DeallocOp>();
416 std::move(patterns))))
424 Type indexMemrefType =
426 Type boolMemrefType =
429 boolMemrefType, boolMemrefType};
433 auto helperFuncOp = func::FuncOp::create(
436 symbolTable.
insert(helperFuncOp);
437 auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
441 Value toDeallocMemref = helperFuncOp.getArguments()[0];
442 Value toRetainMemref = helperFuncOp.getArguments()[1];
443 Value conditionMemref = helperFuncOp.getArguments()[2];
444 Value deallocCondsMemref = helperFuncOp.getArguments()[3];
445 Value retainCondsMemref = helperFuncOp.getArguments()[4];
454 Value toDeallocSize = builder.
create<memref::DimOp>(loc, toDeallocMemref, c0);
455 Value toRetainSize = builder.
create<memref::DimOp>(loc, toRetainMemref, c0);
457 builder.
create<scf::ForOp>(
458 loc, c0, toRetainSize, c1, std::nullopt,
460 builder.
create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
461 builder.
create<scf::YieldOp>(loc);
464 builder.
create<scf::ForOp>(
465 loc, c0, toDeallocSize, c1, std::nullopt,
469 builder.
create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
471 builder.
create<memref::LoadOp>(loc, conditionMemref, outerIter);
475 Value noRetainAlias =
478 loc, c0, toRetainSize, c1, trueValue,
481 Value retainValue = builder.create<memref::LoadOp>(
482 loc, toRetainMemref, i);
483 Value doesAlias = builder.create<arith::CmpIOp>(
484 loc, arith::CmpIPredicate::eq, retainValue,
486 builder.create<scf::IfOp>(
489 Value retainCondValue =
490 builder.create<memref::LoadOp>(
491 loc, retainCondsMemref, i);
492 Value aggregatedRetainCond =
493 builder.create<arith::OrIOp>(
494 loc, retainCondValue, cond);
495 builder.create<memref::StoreOp>(
496 loc, aggregatedRetainCond, retainCondsMemref,
498 builder.create<scf::YieldOp>(loc);
500 Value doesntAlias = builder.create<arith::CmpIOp>(
501 loc, arith::CmpIPredicate::ne, retainValue,
503 Value yieldValue = builder.create<arith::AndIOp>(
504 loc, iterArgs[0], doesntAlias);
505 builder.create<scf::YieldOp>(loc, yieldValue);
514 loc, c0, outerIter, c1, noRetainAlias,
517 Value prevDeallocValue = builder.create<memref::LoadOp>(
518 loc, toDeallocMemref, i);
519 Value doesntAlias = builder.create<arith::CmpIOp>(
520 loc, arith::CmpIPredicate::ne, prevDeallocValue,
522 Value yieldValue = builder.create<arith::AndIOp>(
523 loc, iterArgs[0], doesntAlias);
524 builder.create<scf::YieldOp>(loc, yieldValue);
528 Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond);
529 builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
531 builder.create<scf::YieldOp>(loc);
534 builder.
create<func::ReturnOp>(loc);
540 patterns.
add<DeallocOpConversion>(patterns.
getContext(), deallocLibraryFunc);
544 return std::make_unique<LowerDeallocationsPass>();
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
void clearInsertionPoint()
Reset the insertion point to no location.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the current symbol table.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static WalkResult advance()
std::unique_ptr< Pass > createLowerDeallocationsPass()
Creates an instance of the LowerDeallocations pass to lower bufferization.dealloc operations to the memref.dealloc operation.
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemented in the LowerDeallocations pass.
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use in other transformation passes.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.