MLIR  19.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 using namespace linalg;
22 using namespace mlir::bufferization;
23 
24 namespace {
25 
26 /// Generic conversion for any DestinationStyleOpInterface on tensors.
27 static LogicalResult
28 bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
29  DestinationStyleOpInterface op,
31  // Take a guard before anything else.
32  OpBuilder::InsertionGuard g(rewriter);
33  rewriter.setInsertionPoint(op);
34 
35  // Nothing to do. This op is already bufferized.
36  if (op.hasPureBufferSemantics())
37  return success();
38 
39  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
40  // basis.
41  if (!op.hasPureTensorSemantics())
42  return op->emitError() << "op does not have pure tensor semantics";
43 
44  // New input operands for the cloned op.
45  SmallVector<Value> newInputBuffers;
46  newInputBuffers.reserve(op.getNumDpsInputs());
47  for (OpOperand *opOperand : op.getDpsInputOperands()) {
48  if (op.isScalar(opOperand)) {
49  newInputBuffers.push_back(opOperand->get());
50  continue;
51  }
52  FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
53  if (failed(buffer))
54  return failure();
55  newInputBuffers.push_back(*buffer);
56  }
57 
58  // New output operands for the cloned op.
59  SmallVector<Value> newOutputBuffers;
60  for (OpResult opResult : op->getOpResults()) {
61  OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
62  FailureOr<Value> resultBuffer =
63  getBuffer(rewriter, opOperand->get(), options);
64  if (failed(resultBuffer))
65  return failure();
66  newOutputBuffers.push_back(*resultBuffer);
67  }
68 
69  // Merge input/output operands.
70  SmallVector<Value> newOperands = newInputBuffers;
71  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
72 
73  // Set insertion point now that potential alloc/dealloc are introduced.
74  rewriter.setInsertionPoint(op);
75  // Clone the op, but use the new operands. Move the existing block into the
76  // new op. Since the new op does not have any tensor results, it does not
77  // return anything.
78  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
79  auto newOp = cast<DestinationStyleOpInterface>(cloneWithoutRegions(
80  rewriter, op, /*newResultTypes=*/TypeRange{}, newOperands));
81  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
82  newOp->getRegion(0).begin());
83 
84  // Replace the results of the old op with the new output buffers.
85  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
86 
87  return success();
88 }
89 
90 /// Bufferization of linalg.generic. Replace with a new linalg.generic that
91 /// operates entirely on memrefs.
92 template <typename OpTy>
93 struct LinalgOpInterface
94  : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
95  OpTy> {
96  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
97  const AnalysisState &state) const {
98  // Operand is read if it is used in the computation.
99  auto linalgOp = cast<linalg::LinalgOp>(op);
100  return linalgOp.payloadUsesValueFromOperand(&opOperand);
101  }
102 
103  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
104  const AnalysisState &state) const {
105  // Operand is written to if it is not an input/init.
106  auto dpsOp = cast<DestinationStyleOpInterface>(op);
107  return dpsOp.isDpsInit(&opOperand);
108  }
109 
110  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
111  ArrayRef<OpOperand *> opOperands) const {
112  auto linalgOp = cast<linalg::LinalgOp>(op);
113 
114  // Accesses into sparse data structures are not necessarily elementwise.
116  return false;
117 
118  // All loops must be parallel.
119  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
120  return false;
121 
122  // All index maps of tensors must be identity maps.
123  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
124  assert(linalgOp->getNumOperands() == indexingMaps.size() &&
125  "unexpected number of indexing maps");
126  for (auto [operand, map] :
127  llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
128  // Non-tensors do not participate in bufferization, so they can be
129  // ignored.
130  if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
131  continue;
132  // Only consider operands in `opOperands`.
133  if (!llvm::is_contained(opOperands, &operand))
134  continue;
135  // TODO: This could be generalized to other indexing maps. (All indexing
136  // must be the same.)
137  if (!map.isIdentity())
138  return false;
139  }
140 
141  return true;
142  }
143 
144  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
145  const BufferizationOptions &options) const {
146  return bufferizeDestinationStyleOpInterface(
147  rewriter, cast<DestinationStyleOpInterface>(op), options);
148  }
149 };
150 
151 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
152 /// the `BufferizableOpInterface` with each of them.
153 template <typename... Ops>
154 struct LinalgOpInterfaceHelper {
155  static void registerOpInterface(MLIRContext *ctx) {
156  (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
157  }
158 };
159 } // namespace
160 
162  DialectRegistry &registry) {
163  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
164  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
165  // not possible to attach an external interface to an existing interface.
166  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
167  LinalgOpInterfaceHelper<
168 #define GET_OP_LIST
169 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
170  >::registerOpInterface(ctx);
171  });
172 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_range getOpResults()
Definition: Operation.h:415
iterator begin()
Definition: Region.h:55
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
bool hasAnySparseOperand(Operation *op)
Returns true iff the operation has any sparse operand.
Definition: SparseTensor.h:93
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.