MLIR  17.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/Operation.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::vector;
21 
22 namespace mlir {
23 namespace vector {
24 namespace {
25 
26 /// Bufferization of vector.transfer_read. Replaced with a new
27 /// vector.transfer_read that operates on a memref.
28 struct TransferReadOpInterface
29  : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
30  vector::TransferReadOp> {
31  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
32  const AnalysisState &state) const {
33  assert(opOperand.get().getType().isa<RankedTensorType>() &&
34  "only tensor types expected");
35  return true;
36  }
37 
38  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
39  const AnalysisState &state) const {
40  assert(opOperand.get().getType().isa<RankedTensorType>() &&
41  "only tensor types expected");
42  return false;
43  }
44 
45  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
46  const AnalysisState &state) const {
47  return {};
48  }
49 
50  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51  const BufferizationOptions &options) const {
52  auto readOp = cast<vector::TransferReadOp>(op);
53  assert(readOp.getShapedType().isa<TensorType>() &&
54  "only tensor types expected");
55  FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
56  if (failed(buffer))
57  return failure();
58  replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
59  rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
60  readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
61  readOp.getInBoundsAttr());
62  return success();
63  }
64 };
65 
66 /// Bufferization of vector.transfer_write. Replace with a new
67 /// vector.transfer_write that operates on a memref.
68 ///
69 /// Note: DstBufferizableOpInterfaceExternalModel provides many default method
70 /// implementations for DestinationStyle ops.
71 struct TransferWriteOpInterface
72  : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
73  vector::TransferWriteOp> {
74  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
75  const BufferizationOptions &options) const {
76  auto writeOp = cast<vector::TransferWriteOp>(op);
77  assert(writeOp.getShapedType().isa<TensorType>() &&
78  "only tensor types expected");
79 
80  // Create a new transfer_write on buffer that doesn't have a return value.
81  FailureOr<Value> resultBuffer =
82  getBuffer(rewriter, writeOp.getSource(), options);
83  if (failed(resultBuffer))
84  return failure();
85  rewriter.create<vector::TransferWriteOp>(
86  writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
87  writeOp.getIndices(), writeOp.getPermutationMapAttr(),
88  writeOp.getMask(), writeOp.getInBoundsAttr());
89  replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
90 
91  return success();
92  }
93 };
94 
95 /// Bufferization of vector.gather. Replaced with a new vector.gather that
96 /// operates on a memref.
97 struct GatherOpInterface
98  : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
99  vector::GatherOp> {
100  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
101  const AnalysisState &state) const {
102  assert(opOperand.get().getType().isa<RankedTensorType>() &&
103  "only tensor types expected");
104  return true;
105  }
106 
107  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
108  const AnalysisState &state) const {
109  assert(opOperand.get().getType().isa<RankedTensorType>() &&
110  "only tensor types expected");
111  return false;
112  }
113 
114  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
115  const AnalysisState &state) const {
116  return {};
117  }
118 
119  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
120  const BufferizationOptions &options) const {
121  auto gatherOp = cast<vector::GatherOp>(op);
122  assert(gatherOp.getBaseType().isa<TensorType>() &&
123  "only tensor types expected");
124  FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
125  if (failed(buffer))
126  return failure();
127  replaceOpWithNewBufferizedOp<vector::GatherOp>(
128  rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
129  gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
130  gatherOp.getPassThru());
131  return success();
132  }
133 };
134 
135 /// Bufferization of vector.mask. Replaced with a new vector.mask that
136 /// operates on a memref.
137 struct MaskOpInterface
138  : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
139  vector::MaskOp> {
141  getAliasingOpOperands(Operation *op, OpResult opResult,
142  const AnalysisState &state) const {
143  // MaskOps do not have tensor OpOperands. The yielded values are the result
144  // of the wrapped op.
145  auto maskOp = cast<vector::MaskOp>(op);
146  size_t resultNum = std::distance(op->getOpResults().begin(),
147  llvm::find(op->getOpResults(), opResult));
148  auto yieldOp =
149  cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
150  return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
151  }
152 
153  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
154  const AnalysisState &state) const {
155  auto bufferizableOp = cast<BufferizableOpInterface>(op);
156  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
157  return failure();
158 
159  // TODO: Remove this function when vector.mask bodies can bufferize
160  // out-of-place. This is currently not supported because yielding allocs
161  // from a block leads to a memory leak and because vector.mask supports only
162  // a single op in its body.
163  auto maskOp = cast<vector::MaskOp>(op);
164  if (!maskOp.getMaskRegion()
165  .front()
166  .getOps<bufferization::AllocTensorOp>()
167  .empty())
168  return op->emitOpError("body must bufferize in-place");
169 
170  return success();
171  }
172 
173  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
174  const BufferizationOptions &options) const {
175  auto maskOp = cast<vector::MaskOp>(op);
176 
177  // Do not bufferize if the masked op is not bufferizable.
178  Operation *maskedOp = maskOp.getMaskableOp();
179  if (!options.dynCastBufferizableOp(maskedOp))
180  return success();
181 
182  // Update the terminator: Drop all operands that are not results of the
183  // masked op.
184  auto yieldOp =
185  cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
186  SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
187  SmallVector<Value> newYieldedValues;
188  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
189  if (llvm::find(maskedOp->getOpResults(), it.value()) !=
190  maskedOp->getOpResults().end()) {
191  newYieldedValues.push_back(it.value());
192  } else {
193  // This used to be a tensor result of the masked op, but is now a memref
194  // that is defined outside of the vector.mask op.
195  newReturnValues[it.index()] = it.value();
196  }
197  }
198  rewriter.updateRootInPlace(yieldOp, [&]() {
199  yieldOp.getOperandsMutable().assign(newYieldedValues);
200  });
201 
202  // Create a new vector.mask op.
203  ValueRange newYieldedValuesRange(newYieldedValues);
204  TypeRange newResultTypes(newYieldedValuesRange);
205  auto newOp = rewriter.create<vector::MaskOp>(
206  op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
207  /*maskableOp=*/nullptr,
208  /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
209  newOp.getRegion().takeBody(maskOp.getMaskRegion());
210 
211  // Replace all uses of the old vector.mask op.
212  int idx = 0;
213  for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
214  if (!newReturnValues[i])
215  newReturnValues[i] = newOp->getResult(idx++);
216  }
217  replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
218  return success();
219  }
220 };
221 
222 /// Bufferization of vector.yield. Replaced with a new vector.yield that
223 /// operates on a memref.
224 struct YieldOpInterface
225  : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
226  vector::YieldOp> {
227  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
228  const AnalysisState &state) const {
229  return true;
230  }
231 
232  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
233  const AnalysisState &state) const {
234  return false;
235  }
236 
237  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
238  const AnalysisState &state) const {
239  return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
241  }
242 
243  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
244  const AnalysisState &state) const {
245  // Yield operands always bufferize inplace. Otherwise, an alloc + copy
246  // may be generated inside the block. We should not return/yield allocations
247  // when possible.
248  return true;
249  }
250 
251  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
252  const BufferizationOptions &options) const {
253  auto yieldOp = cast<vector::YieldOp>(op);
254 
255  // Only supported as a vector.mask terminator.
256  auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
257  if (!maskOp)
258  return yieldOp->emitError("unsupported vector::YieldOp parent");
259 
260  // Do not bufferize if the masked op is not bufferizable.
261  Operation *maskedOp = &maskOp.getMaskRegion().front().front();
262  if (!options.dynCastBufferizableOp(maskedOp))
263  return success();
264 
265  // Create a new terminator with the same number of operands. Some of these
266  // may get dropped during the bufferization of vector.mask.
267  SmallVector<Value> newResults;
268  for (Value value : yieldOp.getOperands()) {
269  if (value.getType().isa<TensorType>()) {
270  FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
271  if (failed(maybeBuffer))
272  return failure();
273  newResults.push_back(*maybeBuffer);
274  } else {
275  newResults.push_back(value);
276  }
277  }
278 
279  replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
280  return success();
281  }
282 };
283 
284 } // namespace
285 } // namespace vector
286 } // namespace mlir
287 
289  DialectRegistry &registry) {
290  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
291  TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
292  TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
293  GatherOp::attachInterface<GatherOpInterface>(*ctx);
294  MaskOp::attachInterface<MaskOpInterface>(*ctx);
295  YieldOp::attachInterface<YieldOpInterface>(*ctx);
296  });
297 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:202
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
This class represents an operand of an operation.
Definition: Value.h:255
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
This is a value defined by a result of an operation.
Definition: Value.h:442
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:218
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:550
result_range getOpResults()
Definition: Operation.h:399
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:520
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:549
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:80
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
bool isa() const
Definition: Types.h:301
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.