25 namespace bufferization {
26 #define GEN_PASS_DEF_LOWERDEALLOCATIONS
27 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
42 class DeallocOpConversion
61 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
63 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
64 assert(adaptor.getRetained().empty() &&
"expected no retained memrefs");
68 builder.
create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
69 builder.
create<scf::YieldOp>(loc);
104 LogicalResult rewriteOneMemrefMultipleRetainCase(
105 bufferization::DeallocOp op, OpAdaptor adaptor,
107 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
112 Value memrefAsIdx = rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(
113 op->getLoc(), adaptor.getMemrefs()[0]);
114 for (
Value retained : adaptor.getRetained()) {
115 Value retainedAsIdx =
116 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
118 Value doesNotAlias = rewriter.
create<arith::CmpIOp>(
119 op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
120 doesNotAliasList.push_back(doesNotAlias);
124 Value prev = doesNotAliasList.front();
125 for (
Value doesNotAlias :
ArrayRef(doesNotAliasList).drop_front())
126 prev = rewriter.
create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
130 Value shouldDealloc = rewriter.
create<arith::AndIOp>(
131 op->getLoc(), prev, adaptor.getConditions()[0]);
133 rewriter.
create<scf::IfOp>(
135 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
136 builder.create<scf::YieldOp>(loc);
144 Value trueVal = rewriter.
create<arith::ConstantOp>(
146 for (
Value doesNotAlias : doesNotAliasList) {
148 rewriter.
create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
149 Value result = rewriter.
create<arith::AndIOp>(op->getLoc(), aliases,
150 adaptor.getConditions()[0]);
151 replacements.push_back(result);
225 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
234 Value toDeallocMemref = rewriter.
create<memref::AllocOp>(
237 Value conditionMemref = rewriter.
create<memref::AllocOp>(
238 op.getLoc(),
MemRefType::get({(int64_t)adaptor.getConditions().size()},
240 Value toRetainMemref = rewriter.
create<memref::AllocOp>(
244 auto getConstValue = [&](uint64_t value) ->
Value {
245 return rewriter.
create<arith::ConstantOp>(op.getLoc(),
253 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
255 rewriter.
create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
256 toDeallocMemref, getConstValue(i));
260 rewriter.
create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
265 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
267 rewriter.
create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
274 Value castedDeallocMemref = rewriter.
create<memref::CastOp>(
278 Value castedCondsMemref = rewriter.
create<memref::CastOp>(
282 Value castedRetainMemref = rewriter.
create<memref::CastOp>(
287 Value deallocCondsMemref = rewriter.
create<memref::AllocOp>(
290 Value retainCondsMemref = rewriter.
create<memref::AllocOp>(
294 Value castedDeallocCondsMemref = rewriter.
create<memref::CastOp>(
298 Value castedRetainCondsMemref = rewriter.
create<memref::CastOp>(
304 rewriter.
create<func::CallOp>(
305 op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
307 castedCondsMemref, castedDeallocCondsMemref,
308 castedRetainCondsMemref});
310 for (
unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
311 Value idxValue = getConstValue(i);
312 Value shouldDealloc = rewriter.
create<memref::LoadOp>(
313 op.getLoc(), deallocCondsMemref, idxValue);
314 rewriter.
create<scf::IfOp>(
316 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
317 builder.create<scf::YieldOp>(loc);
322 for (
unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
323 Value idxValue = getConstValue(i);
325 op.getLoc(), retainCondsMemref, idxValue);
326 replacements.push_back(ownership);
331 rewriter.
create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
332 rewriter.
create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
333 rewriter.
create<memref::DeallocOp>(op.getLoc(), conditionMemref);
334 rewriter.
create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
335 rewriter.
create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
346 deallocHelperFuncMap(deallocHelperFuncMap) {}
349 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
352 if (adaptor.getMemrefs().empty()) {
353 Value falseVal = rewriter.
create<arith::ConstantOp>(
360 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
361 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
363 if (adaptor.getMemrefs().size() == 1)
364 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
367 if (!deallocHelperFuncMap.contains(symtableOp))
368 return op->emitError(
369 "library function required for generic lowering, but cannot be "
370 "automatically inserted when operating on functions");
372 return rewriteGeneralCase(op, adaptor, rewriter);
381 struct LowerDeallocationsPass
382 :
public bufferization::impl::LowerDeallocationsBase<
383 LowerDeallocationsPass> {
384 void runOnOperation()
override {
385 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
387 "root operation must be a builtin.module or a function");
393 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
397 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
400 if (deallocOp.getMemrefs().size() > 1 &&
401 !deallocHelperFuncMap.contains(symtableOp)) {
403 func::FuncOp helperFuncOp =
405 builder, getOperation()->getLoc(), symbolTable);
406 deallocHelperFuncMap[symtableOp] = helperFuncOp;
413 patterns, deallocHelperFuncMap);
416 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
417 scf::SCFDialect, func::FuncDialect>();
418 target.addIllegalOp<bufferization::DeallocOp>();
421 std::move(patterns))))
429 Type indexMemrefType =
431 Type boolMemrefType =
434 boolMemrefType, boolMemrefType};
438 auto helperFuncOp = func::FuncOp::create(
441 symbolTable.
insert(helperFuncOp);
442 auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
446 Value toDeallocMemref = helperFuncOp.getArguments()[0];
447 Value toRetainMemref = helperFuncOp.getArguments()[1];
448 Value conditionMemref = helperFuncOp.getArguments()[2];
449 Value deallocCondsMemref = helperFuncOp.getArguments()[3];
450 Value retainCondsMemref = helperFuncOp.getArguments()[4];
459 Value toDeallocSize = builder.
create<memref::DimOp>(loc, toDeallocMemref, c0);
460 Value toRetainSize = builder.
create<memref::DimOp>(loc, toRetainMemref, c0);
462 builder.
create<scf::ForOp>(
463 loc, c0, toRetainSize, c1, std::nullopt,
465 builder.
create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
466 builder.
create<scf::YieldOp>(loc);
469 builder.
create<scf::ForOp>(
470 loc, c0, toDeallocSize, c1, std::nullopt,
474 builder.
create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
476 builder.
create<memref::LoadOp>(loc, conditionMemref, outerIter);
480 Value noRetainAlias =
483 loc, c0, toRetainSize, c1, trueValue,
486 Value retainValue = builder.create<memref::LoadOp>(
487 loc, toRetainMemref, i);
488 Value doesAlias = builder.create<arith::CmpIOp>(
489 loc, arith::CmpIPredicate::eq, retainValue,
491 builder.create<scf::IfOp>(
494 Value retainCondValue =
495 builder.create<memref::LoadOp>(
496 loc, retainCondsMemref, i);
497 Value aggregatedRetainCond =
498 builder.create<arith::OrIOp>(
499 loc, retainCondValue, cond);
500 builder.create<memref::StoreOp>(
501 loc, aggregatedRetainCond, retainCondsMemref,
503 builder.create<scf::YieldOp>(loc);
505 Value doesntAlias = builder.create<arith::CmpIOp>(
506 loc, arith::CmpIPredicate::ne, retainValue,
508 Value yieldValue = builder.create<arith::AndIOp>(
509 loc, iterArgs[0], doesntAlias);
510 builder.create<scf::YieldOp>(loc, yieldValue);
519 loc, c0, outerIter, c1, noRetainAlias,
522 Value prevDeallocValue = builder.create<memref::LoadOp>(
523 loc, toDeallocMemref, i);
524 Value doesntAlias = builder.create<arith::CmpIOp>(
525 loc, arith::CmpIPredicate::ne, prevDeallocValue,
527 Value yieldValue = builder.create<arith::AndIOp>(
528 loc, iterArgs[0], doesntAlias);
529 builder.create<scf::YieldOp>(loc, yieldValue);
533 Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond);
534 builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
536 builder.create<scf::YieldOp>(loc);
539 builder.
create<func::ReturnOp>(loc);
547 deallocHelperFuncMap);
551 return std::make_unique<LowerDeallocationsPass>();
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void clearInsertionPoint()
Reset the insertion point to no location.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::unique_ptr< Pass > createLowerDeallocationsPass()
Creates an instance of the LowerDeallocations pass to lower bufferization.dealloc operations to the m...
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemente...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.