MLIR  16.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir;
18 using namespace linalg;
19 using namespace mlir::bufferization;
20 
21 namespace {
22 
23 /// Generic conversion for any LinalgOp on tensors.
24 static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
26  // Take a guard before anything else.
27  OpBuilder::InsertionGuard g(rewriter);
28  rewriter.setInsertionPoint(op);
29 
30  // Nothing to do. This op is already bufferized.
31  if (op.hasBufferSemantics())
32  return success();
33 
34  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
35  // basis.
36  if (!op.hasTensorSemantics())
37  return op->emitError() << "op does not have tensor semantics";
38 
39  // New input operands for the cloned op.
40  SmallVector<Value> newInputBuffers;
41  newInputBuffers.reserve(op.getNumInputs());
42  for (OpOperand *opOperand : op.getInputOperands()) {
43  if (op.isScalar(opOperand)) {
44  newInputBuffers.push_back(opOperand->get());
45  continue;
46  }
47  FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
48  if (failed(buffer))
49  return failure();
50  newInputBuffers.push_back(*buffer);
51  }
52 
53  // New output operands for the cloned op.
54  SmallVector<Value> newOutputBuffers;
55  for (OpResult opResult : op->getOpResults()) {
56  OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
57  FailureOr<Value> resultBuffer =
58  getBuffer(rewriter, opOperand->get(), options);
59  if (failed(resultBuffer))
60  return failure();
61  newOutputBuffers.push_back(*resultBuffer);
62  }
63 
64  // Merge input/output operands.
65  SmallVector<Value> newOperands = newInputBuffers;
66  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
67 
68  // Set insertion point now that potential alloc/dealloc are introduced.
69  rewriter.setInsertionPoint(op);
70  // Clone the op, but use the new operands. Move the existing block into the
71  // new op. Since the new op does not have any tensor results, it does not
72  // return anything.
73  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
74  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
75  rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
76  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
77  newOp->getRegion(0).begin());
78 
79  // Replace the results of the old op with the new output buffers.
80  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
81 
82  return success();
83 }
84 
85 /// Bufferization of linalg.generic. Replace with a new linalg.generic that
86 /// operates entirely on memrefs.
87 template <typename OpTy>
88 struct LinalgOpInterface
89  : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
90  OpTy> {
91  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
92  const AnalysisState &state) const {
93  // Operand is read if it is used in the computation.
94  auto genericOp = cast<linalg::LinalgOp>(op);
95  return genericOp.payloadUsesValueFromOperand(&opOperand);
96  }
97 
98  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
99  const AnalysisState &state) const {
100  // Operand is written to if it has an aliasing OpResult.
101  auto bufferizableOp = cast<BufferizableOpInterface>(op);
102  return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
103  }
104 
106  getAliasingOpOperand(Operation *op, OpResult opResult,
107  const AnalysisState &state) const {
108  auto genericOp = cast<linalg::LinalgOp>(op);
109 
110  // The i-th OpResult may alias with the i-th "out" tensor.
111  return {genericOp.getOutputOperand(opResult.getResultNumber())};
112  }
113 
114  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
115  const AnalysisState &state) const {
116  auto genericOp = cast<linalg::LinalgOp>(op);
117 
118  // The i-th "out" tensor may alias with the i-th OpResult.
119  if (genericOp.isOutputTensor(&opOperand))
120  return {genericOp.getTiedOpResult(&opOperand)};
121  return {};
122  }
123 
124  BufferRelation bufferRelation(Operation *op, OpResult opResult,
125  const AnalysisState &state) const {
126  return BufferRelation::Equivalent;
127  }
128 
129  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
130  const BufferizationOptions &options) const {
131  return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options);
132  }
133 };
134 
135 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
136 /// the `BufferizableOpInterface` with each of them.
137 template <typename... Ops>
138 struct LinalgOpInterfaceHelper {
139  static void registerOpInterface(MLIRContext *ctx) {
140  (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
141  }
142 };
143 } // namespace
144 
146  DialectRegistry &registry) {
147  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
148  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
149  // not possible to attach an external interface to an existing interface.
150  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
151  LinalgOpInterfaceHelper<
152 #define GET_OP_LIST
153 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
154  >::registerOpInterface(ctx);
155  });
156 }
Include the generated interface declarations.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
This is a value defined by a result of an operation.
Definition: Value.h:425
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:345
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:437
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
Options for BufferizableOpInterface-based bufferization.
static llvm::ManagedStatic< PassManagerOptions > options
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:295
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class represents an operand of an operation.
Definition: Value.h:251
BufferRelation
Specify fine-grain relationship between buffers to enable more analysis.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
Base class for generic analysis states.