MLIR  18.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/Operation.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::vector;
21 
22 namespace mlir {
23 namespace vector {
24 namespace {
25 
26 /// Bufferization of vector.transfer_read. Replaced with a new
27 /// vector.transfer_read that operates on a memref.
28 struct TransferReadOpInterface
29  : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
30  vector::TransferReadOp> {
31  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
32  const AnalysisState &state) const {
33  assert(isa<RankedTensorType>(opOperand.get().getType()) &&
34  "only tensor types expected");
35  return true;
36  }
37 
38  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
39  const AnalysisState &state) const {
40  assert(isa<RankedTensorType>(opOperand.get().getType()) &&
41  "only tensor types expected");
42  return false;
43  }
44 
45  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
46  const AnalysisState &state) const {
47  return {};
48  }
49 
50  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51  const BufferizationOptions &options) const {
52  auto readOp = cast<vector::TransferReadOp>(op);
53  assert(isa<TensorType>(readOp.getShapedType()) &&
54  "only tensor types expected");
55  FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
56  if (failed(buffer))
57  return failure();
58  replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
59  rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
60  readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
61  readOp.getInBoundsAttr());
62  return success();
63  }
64 };
65 
66 /// Bufferization of vector.transfer_write. Replace with a new
67 /// vector.transfer_write that operates on a memref.
68 ///
69 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
70 /// implementations for DestinationStyle ops.
71 struct TransferWriteOpInterface
72  : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
73  vector::TransferWriteOp> {
74  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
75  const AnalysisState &state) const {
76  auto writeOp = cast<vector::TransferWriteOp>(op);
77 
78  // Does not bufferize to a memory read if the vector completely overwrites
79  // the buffer.
80 
81  // Destination must have static shape.
82  if (!writeOp.getShapedType().hasStaticShape())
83  return true;
84 
85  // All offsets must be 0.
86  for (Value offset : writeOp.getIndices()) {
87  if (getConstantIntValue(offset) != 0)
88  return true;
89  }
90 
91  // There is no mask.
92  if (writeOp.isMasked())
93  return true;
94 
95  // Must write at least the full dimension size.
96  for (auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
97  writeOp.getVectorType().getShape())) {
98  if (d0 > d1)
99  return true;
100  }
101 
102  return false;
103  }
104 
105  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
106  const BufferizationOptions &options) const {
107  auto writeOp = cast<vector::TransferWriteOp>(op);
108  assert(isa<TensorType>(writeOp.getShapedType()) &&
109  "only tensor types expected");
110 
111  // Create a new transfer_write on buffer that doesn't have a return value.
112  FailureOr<Value> resultBuffer =
113  getBuffer(rewriter, writeOp.getSource(), options);
114  if (failed(resultBuffer))
115  return failure();
116  rewriter.create<vector::TransferWriteOp>(
117  writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
118  writeOp.getIndices(), writeOp.getPermutationMapAttr(),
119  writeOp.getMask(), writeOp.getInBoundsAttr());
120  replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
121 
122  return success();
123  }
124 };
125 
126 /// Bufferization of vector.gather. Replaced with a new vector.gather that
127 /// operates on a memref.
128 struct GatherOpInterface
129  : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
130  vector::GatherOp> {
131  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
132  const AnalysisState &state) const {
133  assert(isa<RankedTensorType>(opOperand.get().getType()) &&
134  "only tensor types expected");
135  return true;
136  }
137 
138  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
139  const AnalysisState &state) const {
140  assert(isa<RankedTensorType>(opOperand.get().getType()) &&
141  "only tensor types expected");
142  return false;
143  }
144 
145  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
146  const AnalysisState &state) const {
147  return {};
148  }
149 
150  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
151  const BufferizationOptions &options) const {
152  auto gatherOp = cast<vector::GatherOp>(op);
153  assert(isa<TensorType>(gatherOp.getBaseType()) &&
154  "only tensor types expected");
155  FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
156  if (failed(buffer))
157  return failure();
158  replaceOpWithNewBufferizedOp<vector::GatherOp>(
159  rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
160  gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
161  gatherOp.getPassThru());
162  return success();
163  }
164 };
165 
166 /// Bufferization of vector.mask. Replaced with a new vector.mask that
167 /// operates on a memref.
168 struct MaskOpInterface
169  : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
170  vector::MaskOp> {
172  getAliasingOpOperands(Operation *op, Value value,
173  const AnalysisState &state) const {
174  // MaskOps do not have tensor OpOperands. The yielded values are the result
175  // of the wrapped op.
176  auto maskOp = cast<vector::MaskOp>(op);
177  size_t resultNum = std::distance(op->getOpResults().begin(),
178  llvm::find(op->getOpResults(), value));
179  auto yieldOp =
180  cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
181  return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
182  }
183 
184  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
185  const AnalysisState &state) const {
186  auto bufferizableOp = cast<BufferizableOpInterface>(op);
187  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
188  return failure();
189 
190  // TODO: Remove this function when vector.mask bodies can bufferize
191  // out-of-place. This is currently not supported because yielding allocs
192  // from a block leads to a memory leak and because vector.mask supports only
193  // a single op in its body.
194  auto maskOp = cast<vector::MaskOp>(op);
195  if (!maskOp.getMaskRegion()
196  .front()
197  .getOps<bufferization::AllocTensorOp>()
198  .empty())
199  return op->emitOpError("body must bufferize in-place");
200 
201  return success();
202  }
203 
204  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
205  const BufferizationOptions &options) const {
206  auto maskOp = cast<vector::MaskOp>(op);
207 
208  // Do not bufferize if the masked op is not bufferizable.
209  Operation *maskedOp = maskOp.getMaskableOp();
210  if (!options.dynCastBufferizableOp(maskedOp))
211  return success();
212 
213  // Update the terminator: Drop all operands that are not results of the
214  // masked op.
215  auto yieldOp =
216  cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
217  SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
218  SmallVector<Value> newYieldedValues;
219  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
220  if (llvm::find(maskedOp->getOpResults(), it.value()) !=
221  maskedOp->getOpResults().end()) {
222  newYieldedValues.push_back(it.value());
223  } else {
224  // This used to be a tensor result of the masked op, but is now a memref
225  // that is defined outside of the vector.mask op.
226  newReturnValues[it.index()] = it.value();
227  }
228  }
229  rewriter.updateRootInPlace(yieldOp, [&]() {
230  yieldOp.getOperandsMutable().assign(newYieldedValues);
231  });
232 
233  // Create a new vector.mask op.
234  ValueRange newYieldedValuesRange(newYieldedValues);
235  TypeRange newResultTypes(newYieldedValuesRange);
236  auto newOp = rewriter.create<vector::MaskOp>(
237  op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
238  /*maskableOp=*/nullptr,
239  /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
240  newOp.getRegion().takeBody(maskOp.getMaskRegion());
241 
242  // Replace all uses of the old vector.mask op.
243  int idx = 0;
244  for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
245  if (!newReturnValues[i])
246  newReturnValues[i] = newOp->getResult(idx++);
247  }
248  replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
249  return success();
250  }
251 };
252 
253 /// Bufferization of vector.yield. Replaced with a new vector.yield that
254 /// operates on a memref.
255 struct YieldOpInterface
256  : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
257  vector::YieldOp> {
258  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
259  const AnalysisState &state) const {
260  return true;
261  }
262 
263  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
264  const AnalysisState &state) const {
265  return false;
266  }
267 
268  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
269  const AnalysisState &state) const {
270  return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
272  }
273 
274  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
275  const AnalysisState &state) const {
276  // Yield operands always bufferize inplace. Otherwise, an alloc + copy
277  // may be generated inside the block. We should not return/yield allocations
278  // when possible.
279  return true;
280  }
281 
282  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
283  const BufferizationOptions &options) const {
284  auto yieldOp = cast<vector::YieldOp>(op);
285 
286  // Only supported as a vector.mask terminator.
287  auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
288  if (!maskOp)
289  return yieldOp->emitError("unsupported vector::YieldOp parent");
290 
291  // Do not bufferize if the masked op is not bufferizable.
292  Operation *maskedOp = &maskOp.getMaskRegion().front().front();
293  if (!options.dynCastBufferizableOp(maskedOp))
294  return success();
295 
296  // Create a new terminator with the same number of operands. Some of these
297  // may get dropped during the bufferization of vector.mask.
298  SmallVector<Value> newResults;
299  for (Value value : yieldOp.getOperands()) {
300  if (isa<TensorType>(value.getType())) {
301  FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
302  if (failed(maybeBuffer))
303  return failure();
304  newResults.push_back(*maybeBuffer);
305  } else {
306  newResults.push_back(value);
307  }
308  }
309 
310  replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
311  return success();
312  }
313 };
314 
315 } // namespace
316 } // namespace vector
317 } // namespace mlir
318 
320  DialectRegistry &registry) {
321  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
322  TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
323  TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
324  GatherOp::attachInterface<GatherOpInterface>(*ctx);
325  MaskOp::attachInterface<MaskOpInterface>(*ctx);
326  YieldOp::attachInterface<YieldOpInterface>(*ctx);
327  });
328 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
result_range getOpResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.