24 namespace bufferization {
25 #define GEN_PASS_DEF_LOWERDEALLOCATIONSPASS
26 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
41 class DeallocOpConversion
60 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
62 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
63 assert(adaptor.getRetained().empty() &&
"expected no retained memrefs");
67 memref::DeallocOp::create(builder, loc, adaptor.getMemrefs()[0]);
68 scf::YieldOp::create(builder, loc);
103 LogicalResult rewriteOneMemrefMultipleRetainCase(
104 bufferization::DeallocOp op, OpAdaptor adaptor,
106 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
111 Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
112 rewriter, op->getLoc(), adaptor.getMemrefs()[0]);
113 for (
Value retained : adaptor.getRetained()) {
114 Value retainedAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
115 rewriter, op->getLoc(), retained);
116 Value doesNotAlias = arith::CmpIOp::create(rewriter, op->getLoc(),
117 arith::CmpIPredicate::ne,
118 memrefAsIdx, retainedAsIdx);
119 doesNotAliasList.push_back(doesNotAlias);
123 Value prev = doesNotAliasList.front();
124 for (
Value doesNotAlias :
ArrayRef(doesNotAliasList).drop_front())
125 prev = arith::AndIOp::create(rewriter, op->getLoc(), prev, doesNotAlias);
129 Value shouldDealloc = arith::AndIOp::create(rewriter, op->getLoc(), prev,
130 adaptor.getConditions()[0]);
132 scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc,
134 memref::DeallocOp::create(builder, loc,
135 adaptor.getMemrefs()[0]);
136 scf::YieldOp::create(builder, loc);
144 Value trueVal = arith::ConstantOp::create(rewriter, op->getLoc(),
146 for (
Value doesNotAlias : doesNotAliasList) {
148 arith::XOrIOp::create(rewriter, op->getLoc(), doesNotAlias, trueVal);
149 Value result = arith::AndIOp::create(rewriter, op->getLoc(), aliases,
150 adaptor.getConditions()[0]);
151 replacements.push_back(result);
225 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
234 Value toDeallocMemref = memref::AllocOp::create(
235 rewriter, op.getLoc(),
238 Value conditionMemref = memref::AllocOp::create(
239 rewriter, op.getLoc(),
242 Value toRetainMemref = memref::AllocOp::create(
243 rewriter, op.getLoc(),
247 auto getConstValue = [&](uint64_t value) ->
Value {
248 return arith::ConstantOp::create(rewriter, op.getLoc(),
255 Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
256 rewriter, op.getLoc(), toDealloc);
257 memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx,
258 toDeallocMemref, getConstValue(i));
262 memref::StoreOp::create(rewriter, op.getLoc(), cond, conditionMemref,
266 Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create(
267 rewriter, op.getLoc(), toRetain);
268 memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx,
269 toRetainMemref, getConstValue(i));
275 Value castedDeallocMemref = memref::CastOp::create(
276 rewriter, op->getLoc(),
279 Value castedCondsMemref = memref::CastOp::create(
280 rewriter, op->getLoc(),
283 Value castedRetainMemref = memref::CastOp::create(
284 rewriter, op->getLoc(),
288 Value deallocCondsMemref = memref::AllocOp::create(
289 rewriter, op.getLoc(),
292 Value retainCondsMemref = memref::AllocOp::create(
293 rewriter, op.getLoc(),
297 Value castedDeallocCondsMemref = memref::CastOp::create(
298 rewriter, op->getLoc(),
301 Value castedRetainCondsMemref = memref::CastOp::create(
302 rewriter, op->getLoc(),
307 func::CallOp::create(
308 rewriter, op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
310 castedCondsMemref, castedDeallocCondsMemref,
311 castedRetainCondsMemref});
313 for (
unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
314 Value idxValue = getConstValue(i);
315 Value shouldDealloc = memref::LoadOp::create(
316 rewriter, op.getLoc(), deallocCondsMemref, idxValue);
317 scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc,
319 memref::DeallocOp::create(builder, loc,
320 adaptor.getMemrefs()[i]);
321 scf::YieldOp::create(builder, loc);
326 for (
unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
327 Value idxValue = getConstValue(i);
328 Value ownership = memref::LoadOp::create(rewriter, op.getLoc(),
329 retainCondsMemref, idxValue);
330 replacements.push_back(ownership);
335 memref::DeallocOp::create(rewriter, op.getLoc(), toDeallocMemref);
336 memref::DeallocOp::create(rewriter, op.getLoc(), toRetainMemref);
337 memref::DeallocOp::create(rewriter, op.getLoc(), conditionMemref);
338 memref::DeallocOp::create(rewriter, op.getLoc(), deallocCondsMemref);
339 memref::DeallocOp::create(rewriter, op.getLoc(), retainCondsMemref);
350 deallocHelperFuncMap(deallocHelperFuncMap) {}
353 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
356 if (adaptor.getMemrefs().empty()) {
357 Value falseVal = arith::ConstantOp::create(rewriter, op.getLoc(),
364 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
365 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
367 if (adaptor.getMemrefs().size() == 1)
368 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
371 if (!deallocHelperFuncMap.contains(symtableOp))
372 return op->emitError(
373 "library function required for generic lowering, but cannot be "
374 "automatically inserted when operating on functions");
376 return rewriteGeneralCase(op, adaptor, rewriter);
385 struct LowerDeallocationsPass
386 :
public bufferization::impl::LowerDeallocationsPassBase<
387 LowerDeallocationsPass> {
388 void runOnOperation()
override {
389 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
391 "root operation must be a builtin.module or a function");
397 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
401 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
404 if (deallocOp.getMemrefs().size() > 1 &&
405 !deallocHelperFuncMap.contains(symtableOp)) {
407 func::FuncOp helperFuncOp =
409 builder, getOperation()->getLoc(), symbolTable);
410 deallocHelperFuncMap[symtableOp] = helperFuncOp;
420 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
421 scf::SCFDialect, func::FuncDialect>();
422 target.addIllegalOp<bufferization::DeallocOp>();
433 Type indexMemrefType =
435 Type boolMemrefType =
438 boolMemrefType, boolMemrefType};
442 auto helperFuncOp = func::FuncOp::create(
445 symbolTable.
insert(helperFuncOp);
446 auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
450 Value toDeallocMemref = helperFuncOp.getArguments()[0];
451 Value toRetainMemref = helperFuncOp.getArguments()[1];
452 Value conditionMemref = helperFuncOp.getArguments()[2];
453 Value deallocCondsMemref = helperFuncOp.getArguments()[3];
454 Value retainCondsMemref = helperFuncOp.getArguments()[4];
460 arith::ConstantOp::create(builder, loc, builder.
getBoolAttr(
true));
462 arith::ConstantOp::create(builder, loc, builder.
getBoolAttr(
false));
463 Value toDeallocSize =
464 memref::DimOp::create(builder, loc, toDeallocMemref, c0);
465 Value toRetainSize = memref::DimOp::create(builder, loc, toRetainMemref, c0);
468 builder, loc, c0, toRetainSize, c1,
ValueRange(),
470 memref::StoreOp::create(builder, loc, falseValue, retainCondsMemref, i);
471 scf::YieldOp::create(builder, loc);
475 builder, loc, c0, toDeallocSize, c1,
ValueRange(),
479 memref::LoadOp::create(builder, loc, toDeallocMemref, outerIter);
481 memref::LoadOp::create(builder, loc, conditionMemref, outerIter);
488 builder, loc, c0, toRetainSize, c1, trueValue,
491 Value retainValue = memref::LoadOp::create(
492 builder, loc, toRetainMemref, i);
493 Value doesAlias = arith::CmpIOp::create(
494 builder, loc, arith::CmpIPredicate::eq, retainValue,
497 builder, loc, doesAlias,
499 Value retainCondValue = memref::LoadOp::create(
500 builder, loc, retainCondsMemref, i);
501 Value aggregatedRetainCond = arith::OrIOp::create(
502 builder, loc, retainCondValue, cond);
503 memref::StoreOp::create(builder, loc,
504 aggregatedRetainCond,
505 retainCondsMemref, i);
506 scf::YieldOp::create(builder, loc);
508 Value doesntAlias = arith::CmpIOp::create(
509 builder, loc, arith::CmpIPredicate::ne, retainValue,
511 Value yieldValue = arith::AndIOp::create(
512 builder, loc, iterArgs[0], doesntAlias);
513 scf::YieldOp::create(builder, loc, yieldValue);
522 builder, loc, c0, outerIter, c1, noRetainAlias,
525 Value prevDeallocValue = memref::LoadOp::create(
526 builder, loc, toDeallocMemref, i);
527 Value doesntAlias = arith::CmpIOp::create(
528 builder, loc, arith::CmpIPredicate::ne,
529 prevDeallocValue, toDealloc);
530 Value yieldValue = arith::AndIOp::create(
531 builder, loc, iterArgs[0], doesntAlias);
532 scf::YieldOp::create(builder, loc, yieldValue);
536 Value shouldDealoc = arith::AndIOp::create(builder, loc, noAlias, cond);
537 memref::StoreOp::create(builder, loc, shouldDealoc, deallocCondsMemref,
539 scf::YieldOp::create(builder, loc);
542 func::ReturnOp::create(builder, loc);
550 deallocHelperFuncMap);
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void clearInsertionPoint()
Reset the insertion point to no location.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemente...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.