25 namespace bufferization {
26 #define GEN_PASS_DEF_LOWERDEALLOCATIONS
27 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
42 class DeallocOpConversion
61 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
63 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
64 assert(adaptor.getRetained().empty() &&
"expected no retained memrefs");
68 builder.
create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
69 builder.
create<scf::YieldOp>(loc);
104 LogicalResult rewriteOneMemrefMultipleRetainCase(
105 bufferization::DeallocOp op, OpAdaptor adaptor,
107 assert(adaptor.getMemrefs().size() == 1 &&
"expected only one memref");
112 Value memrefAsIdx = rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(
113 op->
getLoc(), adaptor.getMemrefs()[0]);
114 for (
Value retained : adaptor.getRetained()) {
115 Value retainedAsIdx =
116 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op->
getLoc(),
118 Value doesNotAlias = rewriter.
create<arith::CmpIOp>(
119 op->
getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
120 doesNotAliasList.push_back(doesNotAlias);
124 Value prev = doesNotAliasList.front();
125 for (
Value doesNotAlias :
ArrayRef(doesNotAliasList).drop_front())
126 prev = rewriter.
create<arith::AndIOp>(op->
getLoc(), prev, doesNotAlias);
130 Value shouldDealloc = rewriter.
create<arith::AndIOp>(
131 op->
getLoc(), prev, adaptor.getConditions()[0]);
133 rewriter.
create<scf::IfOp>(
135 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
136 builder.create<scf::YieldOp>(loc);
144 Value trueVal = rewriter.
create<arith::ConstantOp>(
146 for (
Value doesNotAlias : doesNotAliasList) {
148 rewriter.
create<arith::XOrIOp>(op->
getLoc(), doesNotAlias, trueVal);
150 adaptor.getConditions()[0]);
151 replacements.push_back(result);
225 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
234 Value toDeallocMemref = rewriter.
create<memref::AllocOp>(
237 Value conditionMemref = rewriter.
create<memref::AllocOp>(
240 Value toRetainMemref = rewriter.
create<memref::AllocOp>(
244 auto getConstValue = [&](uint64_t value) ->
Value {
253 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.
getLoc(),
255 rewriter.
create<memref::StoreOp>(op.
getLoc(), memrefAsIdx,
256 toDeallocMemref, getConstValue(i));
260 rewriter.
create<memref::StoreOp>(op.
getLoc(), cond, conditionMemref,
265 rewriter.
create<memref::ExtractAlignedPointerAsIndexOp>(op.
getLoc(),
267 rewriter.
create<memref::StoreOp>(op.
getLoc(), memrefAsIdx, toRetainMemref,
274 Value castedDeallocMemref = rewriter.
create<memref::CastOp>(
278 Value castedCondsMemref = rewriter.
create<memref::CastOp>(
282 Value castedRetainMemref = rewriter.
create<memref::CastOp>(
287 Value deallocCondsMemref = rewriter.
create<memref::AllocOp>(
290 Value retainCondsMemref = rewriter.
create<memref::AllocOp>(
294 Value castedDeallocCondsMemref = rewriter.
create<memref::CastOp>(
298 Value castedRetainCondsMemref = rewriter.
create<memref::CastOp>(
304 rewriter.
create<func::CallOp>(
305 op.
getLoc(), deallocHelperFuncMap.lookup(symtableOp),
307 castedCondsMemref, castedDeallocCondsMemref,
308 castedRetainCondsMemref});
310 for (
unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
311 Value idxValue = getConstValue(i);
312 Value shouldDealloc = rewriter.
create<memref::LoadOp>(
313 op.
getLoc(), deallocCondsMemref, idxValue);
314 rewriter.
create<scf::IfOp>(
316 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
317 builder.create<scf::YieldOp>(loc);
322 for (
unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
323 Value idxValue = getConstValue(i);
325 op.
getLoc(), retainCondsMemref, idxValue);
326 replacements.push_back(ownership);
331 rewriter.
create<memref::DeallocOp>(op.
getLoc(), toDeallocMemref);
332 rewriter.
create<memref::DeallocOp>(op.
getLoc(), toRetainMemref);
333 rewriter.
create<memref::DeallocOp>(op.
getLoc(), conditionMemref);
334 rewriter.
create<memref::DeallocOp>(op.
getLoc(), deallocCondsMemref);
335 rewriter.
create<memref::DeallocOp>(op.
getLoc(), retainCondsMemref);
346 deallocHelperFuncMap(deallocHelperFuncMap) {}
349 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
352 if (adaptor.getMemrefs().empty()) {
353 Value falseVal = rewriter.
create<arith::ConstantOp>(
360 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
361 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
363 if (adaptor.getMemrefs().size() == 1)
364 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
367 if (!deallocHelperFuncMap.contains(symtableOp))
369 "library function required for generic lowering, but cannot be "
370 "automatically inserted when operating on functions");
372 return rewriteGeneralCase(op, adaptor, rewriter);
381 struct LowerDeallocationsPass
382 :
public bufferization::impl::LowerDeallocationsBase<
383 LowerDeallocationsPass> {
384 void runOnOperation()
override {
385 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
387 "root operation must be a builtin.module or a function");
393 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
398 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
401 if (deallocOp.getMemrefs().size() > 1 &&
402 !deallocHelperFuncMap.contains(symtableOp)) {
404 func::FuncOp helperFuncOp =
406 builder, getOperation()->getLoc(), symbolTable);
407 deallocHelperFuncMap[symtableOp] = helperFuncOp;
414 patterns, deallocHelperFuncMap);
417 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
418 scf::SCFDialect, func::FuncDialect>();
419 target.addIllegalOp<bufferization::DeallocOp>();
422 std::move(patterns))))
430 Type indexMemrefType =
432 Type boolMemrefType =
435 boolMemrefType, boolMemrefType};
439 auto helperFuncOp = func::FuncOp::create(
442 symbolTable.
insert(helperFuncOp);
443 auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
447 Value toDeallocMemref = helperFuncOp.getArguments()[0];
448 Value toRetainMemref = helperFuncOp.getArguments()[1];
449 Value conditionMemref = helperFuncOp.getArguments()[2];
450 Value deallocCondsMemref = helperFuncOp.getArguments()[3];
451 Value retainCondsMemref = helperFuncOp.getArguments()[4];
460 Value toDeallocSize = builder.
create<memref::DimOp>(loc, toDeallocMemref, c0);
461 Value toRetainSize = builder.
create<memref::DimOp>(loc, toRetainMemref, c0);
463 builder.
create<scf::ForOp>(
464 loc, c0, toRetainSize, c1, std::nullopt,
466 builder.
create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
467 builder.
create<scf::YieldOp>(loc);
470 builder.
create<scf::ForOp>(
471 loc, c0, toDeallocSize, c1, std::nullopt,
475 builder.
create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
477 builder.
create<memref::LoadOp>(loc, conditionMemref, outerIter);
481 Value noRetainAlias =
484 loc, c0, toRetainSize, c1, trueValue,
487 Value retainValue = builder.create<memref::LoadOp>(
488 loc, toRetainMemref, i);
489 Value doesAlias = builder.create<arith::CmpIOp>(
490 loc, arith::CmpIPredicate::eq, retainValue,
492 builder.create<scf::IfOp>(
495 Value retainCondValue =
496 builder.create<memref::LoadOp>(
497 loc, retainCondsMemref, i);
498 Value aggregatedRetainCond =
499 builder.create<arith::OrIOp>(
500 loc, retainCondValue, cond);
501 builder.create<memref::StoreOp>(
502 loc, aggregatedRetainCond, retainCondsMemref,
504 builder.create<scf::YieldOp>(loc);
506 Value doesntAlias = builder.create<arith::CmpIOp>(
507 loc, arith::CmpIPredicate::ne, retainValue,
509 Value yieldValue = builder.create<arith::AndIOp>(
510 loc, iterArgs[0], doesntAlias);
511 builder.create<scf::YieldOp>(loc, yieldValue);
520 loc, c0, outerIter, c1, noRetainAlias,
523 Value prevDeallocValue = builder.create<memref::LoadOp>(
524 loc, toDeallocMemref, i);
525 Value doesntAlias = builder.create<arith::CmpIOp>(
526 loc, arith::CmpIPredicate::ne, prevDeallocValue,
528 Value yieldValue = builder.create<arith::AndIOp>(
529 loc, iterArgs[0], doesntAlias);
530 builder.create<scf::YieldOp>(loc, yieldValue);
534 Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond);
535 builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
537 builder.create<scf::YieldOp>(loc);
540 builder.
create<func::ReturnOp>(loc);
548 deallocHelperFuncMap);
552 return std::make_unique<LowerDeallocationsPass>();
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
BoolAttr getBoolAttr(bool value)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void clearInsertionPoint()
Reset the insertion point to no location.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
@ Private
The symbol is private and may only be referenced by SymbolRefAttrs local to the operations within the...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::unique_ptr< Pass > createLowerDeallocationsPass()
Creates an instance of the LowerDeallocations pass to lower bufferization.dealloc operations to the m...
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemente...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.