BufferizableOpInterfaceImpl.cpp
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::vector;
namespace mlir {
namespace vector {
namespace {
/// Bufferization of vector.transfer_read. Replaced with a new
/// vector.transfer_read that operates on a memref.
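///
/// Illustrative example (a sketch with placeholder SSA names; the exact
/// memref layout produced by bufferization may differ):
///
///   %v = vector.transfer_read %t[%c0], %pad {in_bounds = [true]}
///       : tensor<128xf32>, vector<4xf32>
///
/// becomes a read from the buffer backing %t:
///
///   %v = vector.transfer_read %buf[%c0], %pad {in_bounds = [true]}
///       : memref<128xf32, strided<[?], offset: ?>>, vector<4xf32>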
struct TransferReadOpInterface
    : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
                                                    vector::TransferReadOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto readOp = cast<vector::TransferReadOp>(op);
    assert(isa<TensorType>(readOp.getShapedType()) &&
           "only tensor types expected");
    FailureOr<Value> buffer =
        getBuffer(rewriter, readOp.getBase(), options, state);
    if (failed(buffer))
      return failure();
    replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
        rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
        readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
        readOp.getInBoundsAttr());
    return success();
  }
};

/// Bufferization of vector.transfer_write. Replaced with a new
/// vector.transfer_write that operates on a memref.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
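///
/// Illustrative example (a sketch with placeholder SSA names; the exact
/// memref layout produced by bufferization may differ):
///
///   %r = vector.transfer_write %v, %t[%c0] {in_bounds = [true]}
///       : vector<4xf32>, tensor<128xf32>
///
/// becomes an in-place write into the buffer backing %t, with no result:
///
///   vector.transfer_write %v, %buf[%c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<128xf32, strided<[?], offset: ?>>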
struct TransferWriteOpInterface
    : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
                                                     vector::TransferWriteOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);

    // Does not bufferize to a memory read if the vector completely overwrites
    // the buffer.
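    // For example (illustrative, not exhaustive): an unmasked write of a
    // vector<4x8xf32> at indices [0, 0] into a tensor<4x8xf32> destination
    // overwrites every element, so the destination is never read.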

    // Destination must have static shape.
    if (!writeOp.getShapedType().hasStaticShape())
      return true;

    // All offsets must be 0.
    for (Value offset : writeOp.getIndices()) {
      if (getConstantIntValue(offset) != 0)
        return true;
    }

    // There is no mask.
    if (writeOp.isMasked())
      return true;

    // Must write at least the full dimension size.
    for (auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
                                   writeOp.getVectorType().getShape())) {
      if (d0 > d1)
        return true;
    }

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);
    assert(isa<TensorType>(writeOp.getShapedType()) &&
           "only tensor types expected");

    // Create a new transfer_write on buffer that doesn't have a return value.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, writeOp.getBase(), options, state);
    if (failed(resultBuffer))
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
        writeOp.getIndices(), writeOp.getPermutationMapAttr(),
        writeOp.getMask(), writeOp.getInBoundsAttr());
    replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);

    return success();
  }
};

/// Bufferization of vector.gather. Replaced with a new vector.gather that
/// operates on a memref.
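///
/// Illustrative example (a sketch with placeholder SSA names; the exact
/// memref layout produced by bufferization may differ):
///
///   %g = vector.gather %t[%c0] [%idx], %mask, %pass_thru
///       : tensor<64xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///         into vector<16xf32>
///
/// becomes a gather from the buffer backing %t:
///
///   %g = vector.gather %buf[%c0] [%idx], %mask, %pass_thru
///       : memref<64xf32, strided<[?], offset: ?>>, vector<16xi32>,
///         vector<16xi1>, vector<16xf32> into vector<16xf32>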
struct GatherOpInterface
    : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
                                                    vector::GatherOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto gatherOp = cast<vector::GatherOp>(op);
    assert(isa<TensorType>(gatherOp.getBaseType()) &&
           "only tensor types expected");
    FailureOr<Value> buffer =
        getBuffer(rewriter, gatherOp.getBase(), options, state);
    if (failed(buffer))
      return failure();
    replaceOpWithNewBufferizedOp<vector::GatherOp>(
        rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
        gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
        gatherOp.getPassThru());
    return success();
  }
};

/// Bufferization of vector.mask. Replaced with a new vector.mask that
/// operates on a memref.
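///
/// Illustrative example (a sketch with placeholder SSA names; the exact
/// memref layout produced by bufferization may differ): a masked
/// transfer_write on a tensor
///
///   %r = vector.mask %m {
///     vector.transfer_write %v, %t[%c0] {in_bounds = [true]}
///         : vector<16xf32>, tensor<128xf32>
///   } : vector<16xi1> -> tensor<128xf32>
///
/// becomes a masked write into the buffer backing %t, and the vector.mask op
/// no longer returns a tensor:
///
///   vector.mask %m {
///     vector.transfer_write %v, %buf[%c0] {in_bounds = [true]}
///         : vector<16xf32>, memref<128xf32, strided<[?], offset: ?>>
///   } : vector<16xi1>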
struct MaskOpInterface
    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
                                                    vector::MaskOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // MaskOps do not have tensor OpOperands. The yielded values are the result
    // of the wrapped op.
    auto maskOp = cast<vector::MaskOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    // TODO: Remove this function when vector.mask bodies can bufferize
    // out-of-place. This is currently not supported because yielding allocs
    // from a block leads to a memory leak and because vector.mask supports
    // only a single op in its body.
    auto maskOp = cast<vector::MaskOp>(op);
    if (!maskOp.getMaskRegion()
             .front()
             .getOps<bufferization::AllocTensorOp>()
             .empty())
      return op->emitOpError("body must bufferize in-place");

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto maskOp = cast<vector::MaskOp>(op);

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = maskOp.getMaskableOp();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Update the terminator: Drop all operands that are not results of the
    // masked op.
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
    SmallVector<Value> newYieldedValues;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      if (llvm::is_contained(maskedOp->getOpResults(), it.value())) {
        newYieldedValues.push_back(it.value());
      } else {
        // This used to be a tensor result of the masked op, but is now a
        // memref that is defined outside of the vector.mask op.
        newReturnValues[it.index()] = it.value();
      }
    }
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().assign(newYieldedValues);
    });

    // Create a new vector.mask op.
    ValueRange newYieldedValuesRange(newYieldedValues);
    TypeRange newResultTypes(newYieldedValuesRange);
    auto newOp = rewriter.create<vector::MaskOp>(
        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
        /*maskableOp=*/nullptr,
        /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
    newOp.getRegion().takeBody(maskOp.getMaskRegion());

    // Replace all uses of the old vector.mask op.
    int idx = 0;
    for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
      if (!newReturnValues[i])
        newReturnValues[i] = newOp->getResult(idx++);
    }
    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
    return success();
  }
};

/// Bufferization of vector.yield. Replaced with a new vector.yield that
/// operates on a memref.
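///
/// Illustrative note (a sketch): inside a vector.mask region, a terminator
/// such as `vector.yield %t : tensor<128xf32>` has its tensor operands
/// replaced by the corresponding buffers; operands that are not produced by
/// the masked op are then dropped when the enclosing vector.mask is
/// bufferized.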
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    vector::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
             BufferRelation::Equivalent}};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<vector::YieldOp>(op);

    // Only supported as a vector.mask terminator.
    auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
    if (!maskOp)
      return yieldOp->emitError("unsupported vector::YieldOp parent");

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = &maskOp.getMaskRegion().front().front();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Create a new terminator with the same number of operands. Some of these
    // may get dropped during the bufferization of vector.mask.
    SmallVector<Value> newResults;
    for (Value value : yieldOp.getOperands()) {
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        newResults.push_back(*maybeBuffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

} // namespace
} // namespace vector
} // namespace mlir
void mlir::vector::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
    TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
    TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
    GatherOp::attachInterface<GatherOpInterface>(*ctx);
    MaskOp::attachInterface<MaskOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
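
// Typical usage (a sketch, not part of this file): clients that run one-shot
// bufferization over vector ops on tensors register these external models on
// their MLIRContext before bufferizing, e.g.:
//
//   DialectRegistry registry;
//   vector::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);
//
// after which the vector ops above implement BufferizableOpInterface.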