24 namespace bufferization {
25 #define GEN_PASS_DEF_LOWERDEALLOCATIONSPASS
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
41 class DeallocOpConversion
60 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
62 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
63 assert(adaptor.getRetained().empty() &&
"expected no retained memrefs");
67 builder.
create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
68 builder.
create<scf::YieldOp>(loc);
103 LogicalResult rewriteOneMemrefMultipleRetainCase(
104 bufferization::DeallocOp op, OpAdaptor adaptor,
106 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
111 Value memrefAsIdx = rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(
112 op->getLoc(), adaptor.getMemrefs()[0]);
113 for (
Value retained : adaptor.getRetained()) {
114 Value retainedAsIdx =
115 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
117 Value doesNotAlias = rewriter.
create<arith::CmpIOp>(
118 op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
119 doesNotAliasList.push_back(doesNotAlias);
123 Value prev = doesNotAliasList.front();
124 for (
Value doesNotAlias :
ArrayRef(doesNotAliasList).drop_front())
125 prev = rewriter.
create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
129 Value shouldDealloc = rewriter.
create<arith::AndIOp>(
130 op->getLoc(), prev, adaptor.getConditions()[0]);
132 rewriter.
create<scf::IfOp>(
134 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
135 builder.create<scf::YieldOp>(loc);
143 Value trueVal = rewriter.
create<arith::ConstantOp>(
145 for (
Value doesNotAlias : doesNotAliasList) {
147 rewriter.
create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
148 Value result = rewriter.
create<arith::AndIOp>(op->getLoc(), aliases,
149 adaptor.getConditions()[0]);
150 replacements.push_back(result);
224 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
233 Value toDeallocMemref = rewriter.
create<memref::AllocOp>(
236 Value conditionMemref = rewriter.
create<memref::AllocOp>(
237 op.getLoc(),
MemRefType::get({(int64_t)adaptor.getConditions().size()},
239 Value toRetainMemref = rewriter.
create<memref::AllocOp>(
243 auto getConstValue = [&](uint64_t value) ->
Value {
244 return rewriter.
create<arith::ConstantOp>(op.getLoc(),
252 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
254 rewriter.
create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
255 toDeallocMemref, getConstValue(i));
259 rewriter.
create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
264 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
266 rewriter.
create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
273 Value castedDeallocMemref = rewriter.
create<memref::CastOp>(
277 Value castedCondsMemref = rewriter.
create<memref::CastOp>(
281 Value castedRetainMemref = rewriter.
create<memref::CastOp>(
286 Value deallocCondsMemref = rewriter.
create<memref::AllocOp>(
289 Value retainCondsMemref = rewriter.
create<memref::AllocOp>(
293 Value castedDeallocCondsMemref = rewriter.
create<memref::CastOp>(
297 Value castedRetainCondsMemref = rewriter.
create<memref::CastOp>(
303 rewriter.
create<func::CallOp>(
304 op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
306 castedCondsMemref, castedDeallocCondsMemref,
307 castedRetainCondsMemref});
309 for (
unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
310 Value idxValue = getConstValue(i);
311 Value shouldDealloc = rewriter.
create<memref::LoadOp>(
312 op.getLoc(), deallocCondsMemref, idxValue);
313 rewriter.
create<scf::IfOp>(
315 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
316 builder.create<scf::YieldOp>(loc);
321 for (
unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
322 Value idxValue = getConstValue(i);
324 op.getLoc(), retainCondsMemref, idxValue);
325 replacements.push_back(ownership);
330 rewriter.
create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
331 rewriter.
create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
332 rewriter.
create<memref::DeallocOp>(op.getLoc(), conditionMemref);
333 rewriter.
create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
334 rewriter.
create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
345 deallocHelperFuncMap(deallocHelperFuncMap) {}
348 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
351 if (adaptor.getMemrefs().empty()) {
352 Value falseVal = rewriter.
create<arith::ConstantOp>(
359 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
360 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
362 if (adaptor.getMemrefs().size() == 1)
363 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
366 if (!deallocHelperFuncMap.contains(symtableOp))
367 return op->emitError(
368 "library function required for generic lowering, but cannot be "
369 "automatically inserted when operating on functions");
371 return rewriteGeneralCase(op, adaptor, rewriter);
380 struct LowerDeallocationsPass
381 :
public bufferization::impl::LowerDeallocationsPassBase<
382 LowerDeallocationsPass> {
383 void runOnOperation()
override {
384 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
386 "root operation must be a builtin.module or a function");
392 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
396 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
399 if (deallocOp.getMemrefs().size() > 1 &&
400 !deallocHelperFuncMap.contains(symtableOp)) {
402 func::FuncOp helperFuncOp =
404 builder, getOperation()->getLoc(), symbolTable);
405 deallocHelperFuncMap[symtableOp] = helperFuncOp;
415 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
416 scf::SCFDialect, func::FuncDialect>();
417 target.addIllegalOp<bufferization::DeallocOp>();
428 Type indexMemrefType =
430 Type boolMemrefType =
433 boolMemrefType, boolMemrefType};
437 auto helperFuncOp = func::FuncOp::create(
440 symbolTable.
insert(helperFuncOp);
441 auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
445 Value toDeallocMemref = helperFuncOp.getArguments()[0];
446 Value toRetainMemref = helperFuncOp.getArguments()[1];
447 Value conditionMemref = helperFuncOp.getArguments()[2];
448 Value deallocCondsMemref = helperFuncOp.getArguments()[3];
449 Value retainCondsMemref = helperFuncOp.getArguments()[4];
458 Value toDeallocSize = builder.
create<memref::DimOp>(loc, toDeallocMemref, c0);
459 Value toRetainSize = builder.
create<memref::DimOp>(loc, toRetainMemref, c0);
461 builder.
create<scf::ForOp>(
464 builder.
create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
465 builder.
create<scf::YieldOp>(loc);
468 builder.
create<scf::ForOp>(
473 builder.
create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
475 builder.
create<memref::LoadOp>(loc, conditionMemref, outerIter);
479 Value noRetainAlias =
482 loc, c0, toRetainSize, c1, trueValue,
485 Value retainValue = builder.create<memref::LoadOp>(
486 loc, toRetainMemref, i);
487 Value doesAlias = builder.create<arith::CmpIOp>(
488 loc, arith::CmpIPredicate::eq, retainValue,
490 builder.create<scf::IfOp>(
493 Value retainCondValue =
494 builder.create<memref::LoadOp>(
495 loc, retainCondsMemref, i);
496 Value aggregatedRetainCond =
497 builder.create<arith::OrIOp>(
498 loc, retainCondValue, cond);
499 builder.create<memref::StoreOp>(
500 loc, aggregatedRetainCond, retainCondsMemref,
502 builder.create<scf::YieldOp>(loc);
504 Value doesntAlias = builder.create<arith::CmpIOp>(
505 loc, arith::CmpIPredicate::ne, retainValue,
507 Value yieldValue = builder.create<arith::AndIOp>(
508 loc, iterArgs[0], doesntAlias);
509 builder.create<scf::YieldOp>(loc, yieldValue);
518 loc, c0, outerIter, c1, noRetainAlias,
521 Value prevDeallocValue = builder.create<memref::LoadOp>(
522 loc, toDeallocMemref, i);
523 Value doesntAlias = builder.create<arith::CmpIOp>(
524 loc, arith::CmpIPredicate::ne, prevDeallocValue,
526 Value yieldValue = builder.create<arith::AndIOp>(
527 loc, iterArgs[0], doesntAlias);
528 builder.create<scf::YieldOp>(loc, yieldValue);
532 Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond);
533 builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
535 builder.create<scf::YieldOp>(loc);
538 builder.
create<func::ReturnOp>(loc);
546 deallocHelperFuncMap);
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void clearInsertionPoint()
Reset the insertion point to no location.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemente...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.