MLIR  22.0.0git
BufferDeallocationOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 using namespace mlir;
14 using namespace mlir::bufferization;
15 
16 namespace {
17 /// The `scf.forall.in_parallel` terminator is special in a few ways:
18 /// * It does not implement the BranchOpInterface or the
19 ///   RegionBranchTerminatorOpInterface. Instead it implements the
20 ///   ParallelCombiningOpInterface, which BufferDeallocation does not support.
21 /// * It has a graph-like region which only allows one specific tensor op.
22 /// * After bufferization the nested region is always empty.
23 /// For these reasons we provide custom deallocation logic via this external
24 /// model.
25 ///
26 /// Example:
27 /// ```mlir
28 /// scf.forall (%arg1) in (%arg0) {
29 ///   %alloc = memref.alloc() : memref<2xf32>
30 ///   ...
31 ///   <implicit in_parallel terminator here>
32 /// }
33 /// ```
34 /// gets transformed to
35 /// ```mlir
36 /// scf.forall (%arg1) in (%arg0) {
37 ///   %alloc = memref.alloc() : memref<2xf32>
38 ///   ...
39 ///   bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
40 ///   <implicit in_parallel terminator here>
41 /// }
42 /// ```
43 struct InParallelOpInterface
44  : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
45  scf::InParallelOp> {
   /// Inserts the deallocation logic for an `scf::InParallelOp` terminator.
   ///
   /// \param op      The operation being processed; cast to scf::InParallelOp.
   /// \param state   Deallocation state forwarded to the return-like helper.
   /// \param options Deallocation options (not referenced by this model).
   /// \returns The result of the return-like dealloc helper, or failure (with
   ///          an error diagnostic) if the nested body is not empty.
46  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
47  const DeallocationOptions &options) const {
48  auto inParallelOp = cast<scf::InParallelOp>(op);
   // The class comment notes that the nested region is always empty after
   // bufferization; anything else is rejected. The InFlightDiagnostic
   // returned by emitError converts to a failure result here.
49  if (!inParallelOp.getBody()->empty())
50  return op->emitError("only supported when nested region is empty");
51 
52  SmallVector<Value> updatedOperandOwnership;
   // NOTE(review): this listing drops source line 53, which holds the head of
   // the call consuming the arguments on the next line. Per the cross-reference
   // list on this page it appears to be `insertDeallocOpForReturnLike(state,
   // op, operands, updatedOperandOwnerships)`, documented as inserting a
   // `bufferization.dealloc` right before `op`. An empty operand list (`{}`)
   // is passed, i.e. no values of this terminator are handed to the helper as
   // returned operands — presumably nothing needs to outlive the iteration;
   // confirm against the helper's documentation.
54  state, op, {}, updatedOperandOwnership);
55  }
56 };
57 
/// Buffer deallocation external model for the `scf::ReduceReturnOp` terminator.
///
/// Only a non-MemRef operand is supported. In that case a
/// `bufferization.dealloc` is inserted right before the terminator, the same way
/// as for `scf::InParallelOp`, and no operand is passed through as a returned
/// value.
58 struct ReduceReturnOpInterface
59  : public BufferDeallocationOpInterface::ExternalModel<
60  ReduceReturnOpInterface, scf::ReduceReturnOp> {
   /// Inserts the deallocation logic for an `scf::ReduceReturnOp` terminator.
   ///
   /// \param op      The operation being processed; cast to scf::ReduceReturnOp.
   /// \param state   Deallocation state forwarded to the return-like helper.
   /// \param options Deallocation options (not referenced by this model).
   /// \returns The result of the return-like dealloc helper, or failure (with
   ///          an error diagnostic) if the returned operand is a MemRef.
61  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
62  const DeallocationOptions &options) const {
63  auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
   // A MemRef-typed (any BaseMemRefType) operand is rejected. Presumably this
   // is because transferring buffer ownership out of the reduction region is
   // not modeled here. TODO confirm.
64  if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
65  return op->emitError("only supported when operand is not a MemRef");
66 
67  SmallVector<Value> updatedOperandOwnership;
   // NOTE(review): this listing drops source line 68, the head of the call
   // consuming the arguments on the next line. Per the cross-reference list on
   // this page it appears to be `insertDeallocOpForReturnLike(...)`, which
   // inserts a `bufferization.dealloc` right before `op`. The operand list is
   // empty (`{}`) because the (non-MemRef) operand needs no ownership tracking.
69  state, op, {}, updatedOperandOwnership);
70  }
71 };
72 
73 } // namespace
74 
76  DialectRegistry &registry) {
  // NOTE(review): the head of this signature (source line 75) is missing from
  // this listing. The cross-reference list names the function
  // `registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry &)`.
  //
  // Registers a dialect extension on `registry` instead of attaching the
  // models eagerly. The callback receives the context and the SCF dialect
  // (`dialect` is unused) and is presumably invoked once the SCF dialect is
  // loaded into a context.
77  registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
  // Attach the two deallocation external models defined in this file to the
  // SCF ops they handle.
78  InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
79  ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
80  });
81 }
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
This class collects all the state that we need to perform the buffer deallocation pass with associate...
FailureOr< Operation * > insertDeallocOpForReturnLike(DeallocationState &state, Operation *op, ValueRange operands, SmallVectorImpl< Value > &updatedOperandOwnerships)
Insert a bufferization.dealloc operation right before op which has to be a terminator without any suc...
void registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Options for BufferDeallocationOpInterface-based buffer deallocation.