MLIR  22.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
17 
18 using namespace mlir;
19 using namespace linalg;
20 using namespace mlir::bufferization;
21 
22 namespace {
23 
24 /// Generic conversion for any DestinationStyleOpInterface on tensors.
25 static LogicalResult bufferizeDestinationStyleOpInterface(
26  RewriterBase &rewriter, DestinationStyleOpInterface op,
27  const BufferizationOptions &options, const BufferizationState &state) {
28  // Take a guard before anything else.
29  OpBuilder::InsertionGuard g(rewriter);
30  rewriter.setInsertionPoint(op);
31 
32  // Nothing to do. This op is already bufferized.
33  if (op.hasPureBufferSemantics())
34  return success();
35 
36  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
37  // basis.
38  if (!op.hasPureTensorSemantics())
39  return op->emitError() << "op does not have pure tensor semantics";
40 
41  // New input operands for the cloned op.
42  SmallVector<Value> newInputBuffers;
43  newInputBuffers.reserve(op.getNumDpsInputs());
44  for (OpOperand *opOperand : op.getDpsInputOperands()) {
45  if (op.isScalar(opOperand)) {
46  newInputBuffers.push_back(opOperand->get());
47  continue;
48  }
49  FailureOr<Value> buffer =
50  getBuffer(rewriter, opOperand->get(), options, state);
51  if (failed(buffer))
52  return failure();
53  newInputBuffers.push_back(*buffer);
54  }
55 
56  // New output operands for the cloned op.
57  SmallVector<Value> newOutputBuffers;
58  for (OpResult opResult : op->getOpResults()) {
59  OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
60  FailureOr<Value> resultBuffer =
61  getBuffer(rewriter, opOperand->get(), options, state);
62  if (failed(resultBuffer))
63  return failure();
64  newOutputBuffers.push_back(*resultBuffer);
65  }
66 
67  // Merge input/output operands.
68  SmallVector<Value> newOperands = newInputBuffers;
69  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
70 
71  // Set insertion point now that potential alloc/dealloc are introduced.
72  rewriter.setInsertionPoint(op);
73  // Clone the op, but use the new operands. Move the existing block into the
74  // new op. Since the new op does not have any tensor results, it does not
75  // return anything.
76  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
77  OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{},
78  op->getAttrs());
79  opState.addRegion();
80  Operation *newOp = Operation::create(opState);
81  newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
82  op->getRegion(0).getBlocks());
83 
84  // We don't want the rewriter tracks an incomplete operation, so insert new
85  // operation after op was fully constructed.
86  rewriter.insert(newOp);
87 
88  // Replace the results of the old op with the new output buffers.
89  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
90 
91  return success();
92 }
93 
94 /// Bufferization of linalg.generic. Replace with a new linalg.generic that
95 /// operates entirely on memrefs.
96 template <typename OpTy>
97 struct LinalgOpInterface
98  : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
99  OpTy> {
100  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
101  const AnalysisState &state) const {
102  // Operand is read if it is used in the computation.
103  auto linalgOp = cast<linalg::LinalgOp>(op);
104  return linalgOp.payloadUsesValueFromOperand(&opOperand);
105  }
106 
107  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
108  const AnalysisState &state) const {
109  // Operand is written to if it is not an input/init.
110  auto dpsOp = cast<DestinationStyleOpInterface>(op);
111  return dpsOp.isDpsInit(&opOperand);
112  }
113 
114  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
115  ArrayRef<OpOperand *> opOperands) const {
116  auto linalgOp = cast<linalg::LinalgOp>(op);
117 
118  // Accesses into sparse data structures are not necessarily elementwise.
120  return false;
121 
122  // All loops must be parallel.
123  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
124  return false;
125 
126  // All index maps of tensors must be identity maps.
127  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
128  assert(linalgOp->getNumOperands() == indexingMaps.size() &&
129  "unexpected number of indexing maps");
130  for (auto [operand, map] :
131  llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
132  // Non-tensors do not participate in bufferization, so they can be
133  // ignored.
134  if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
135  continue;
136  // Only consider operands in `opOperands`.
137  if (!llvm::is_contained(opOperands, &operand))
138  continue;
139  // TODO: This could be generalized to other indexing maps. (All indexing
140  // must be the same.)
141  if (!map.isIdentity())
142  return false;
143  }
144 
145  return true;
146  }
147 
148  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
150  BufferizationState &state) const {
151  return bufferizeDestinationStyleOpInterface(
152  rewriter, cast<DestinationStyleOpInterface>(op), options, state);
153  }
154 };
155 
156 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
157 /// the `BufferizableOpInterface` with each of them.
158 template <typename... Ops>
159 struct LinalgOpInterfaceHelper {
160  static void registerOpInterface(MLIRContext *ctx) {
161  (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
162  }
163 };
164 
165 struct SoftmaxOpInterface
166  : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
167  linalg::SoftmaxOp> {
168  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
169  const AnalysisState &state) const {
170  // Output operand is not read.
171  auto softmaxOp = cast<linalg::SoftmaxOp>(op);
172  return &opOperand == &softmaxOp.getInputMutable();
173  }
174 
175  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
177  BufferizationState &state) const {
178  auto softmaxOp = cast<linalg::SoftmaxOp>(op);
179  FailureOr<Value> inputBuffer =
180  getBuffer(rewriter, softmaxOp.getInput(), options, state);
181  if (failed(inputBuffer))
182  return failure();
183  FailureOr<Value> outputBuffer =
184  getBuffer(rewriter, softmaxOp.getOutput(), options, state);
185  if (failed(outputBuffer))
186  return failure();
187  linalg::SoftmaxOp::create(rewriter, softmaxOp.getLoc(),
188  /*result=*/TypeRange(), *inputBuffer,
189  *outputBuffer, softmaxOp.getDimension());
190  replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
191  return success();
192  }
193 };
194 } // namespace
195 
197  DialectRegistry &registry) {
198  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
199  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
200  // not possible to attach an external interface to an existing interface.
201  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
202  LinalgOpInterfaceHelper<
203 #define GET_OP_LIST
204 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
205 
206  >::registerOpInterface(ctx);
207 
208  SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
209  });
210 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition: Builders.cpp:416
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
iterator begin()
Definition: Region.h:55
BlockListType & getBlocks()
Definition: Region.h:45
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
BufferizationState provides information about the state of the IR during the bufferization process.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
bool hasAnySparseOperand(Operation *op)
Returns true iff the MLIR operation has any sparse operand.
Definition: SparseTensor.h:186
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.