MLIR  19.0.0git
BufferDeallocationOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 
14 using namespace mlir;
15 using namespace mlir::bufferization;
16 
17 namespace {
18 /// The `scf.forall.in_parallel` terminator is special in a few ways:
19 /// * It does not implement the BranchOpInterface or
20 /// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
21 /// which is not supported by BufferDeallocation.
22 /// * It has a graph-like region which only allows one specific tensor op
23 /// * After bufferization the nested region is always empty
24 /// For these reasons we provide custom deallocation logic via this external
25 /// model.
26 ///
27 /// Example:
28 /// ```mlir
29 /// scf.forall (%arg1) in (%arg0) {
30 /// %alloc = memref.alloc() : memref<2xf32>
31 /// ...
32 /// <implicit in_parallel terminator here>
33 /// }
34 /// ```
35 /// gets transformed to
36 /// ```mlir
37 /// scf.forall (%arg1) in (%arg0) {
38 /// %alloc = memref.alloc() : memref<2xf32>
39 /// ...
40 /// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
41 /// <implicit in_parallel terminator here>
42 /// }
43 /// ```
44 struct InParallelOpInterface
45  : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
46  scf::InParallelOp> {
47  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
48  const DeallocationOptions &options) const {
49  auto inParallelOp = cast<scf::InParallelOp>(op);
50  if (!inParallelOp.getBody()->empty())
51  return op->emitError("only supported when nested region is empty");
52 
53  SmallVector<Value> updatedOperandOwnership;
55  state, op, {}, updatedOperandOwnership);
56  }
57 };
58 
59 struct ReduceReturnOpInterface
60  : public BufferDeallocationOpInterface::ExternalModel<
61  ReduceReturnOpInterface, scf::ReduceReturnOp> {
62  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
63  const DeallocationOptions &options) const {
64  auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
65  if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
66  return op->emitError("only supported when operand is not a MemRef");
67 
68  SmallVector<Value> updatedOperandOwnership;
70  state, op, {}, updatedOperandOwnership);
71  }
72 };
73 
74 } // namespace
75 
77  DialectRegistry &registry) {
78  registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
79  InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
80  ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
81  });
82 }
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
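This class collects all the state that we need to perform the buffer deallocation pass with associated helper functions such that we have easy access to it in the BufferDeallocationOpInterface implementations and the BufferDeallocation pass.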
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any successors.
void registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Options for BufferDeallocationOpInterface-based buffer deallocation.