MLIR 22.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
18
19using namespace mlir;
20using namespace mlir::bufferization;
21using namespace mlir::vector;
22
23namespace mlir {
24namespace vector {
25namespace {
26
27/// Bufferization of vector.transfer_read. Replaced with a new
28/// vector.transfer_read that operates on a memref.
29struct TransferReadOpInterface
30 : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
31 vector::TransferReadOp> {
32 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
33 const AnalysisState &state) const {
34 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
35 "only tensor types expected");
36 return true;
37 }
38
39 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
40 const AnalysisState &state) const {
41 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
42 "only tensor types expected");
43 return false;
44 }
45
46 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
47 const AnalysisState &state) const {
48 return {};
49 }
50
51 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
52 const BufferizationOptions &options,
53 BufferizationState &state) const {
54 auto readOp = cast<vector::TransferReadOp>(op);
55 assert(isa<TensorType>(readOp.getShapedType()) &&
56 "only tensor types expected");
57 FailureOr<Value> buffer =
58 getBuffer(rewriter, readOp.getBase(), options, state);
59 if (failed(buffer))
60 return failure();
61 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
62 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
63 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
64 readOp.getInBoundsAttr());
65 return success();
66 }
67};
68
69/// Bufferization of vector.transfer_write. Replace with a new
70/// vector.transfer_write that operates on a memref.
71///
72/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
73/// implementations for DestinationStyle ops.
74struct TransferWriteOpInterface
75 : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
76 vector::TransferWriteOp> {
77 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
78 const AnalysisState &state) const {
79 auto writeOp = cast<vector::TransferWriteOp>(op);
80
81 // Does not bufferize to a memory read if the vector completely overwrites
82 // the buffer.
83
84 // Destination must have static shape.
85 if (!writeOp.getShapedType().hasStaticShape())
86 return true;
87
88 // All offsets must be 0.
89 for (Value offset : writeOp.getIndices()) {
90 if (getConstantIntValue(offset) != 0)
91 return true;
92 }
93
94 // There is no mask.
95 if (writeOp.isMasked())
96 return true;
97
98 // Must write at least the full dimension size.
99 for (auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
100 writeOp.getVectorType().getShape())) {
101 if (d0 > d1)
102 return true;
103 }
104
105 return false;
106 }
107
108 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
109 const BufferizationOptions &options,
110 BufferizationState &state) const {
111 auto writeOp = cast<vector::TransferWriteOp>(op);
112 assert(isa<TensorType>(writeOp.getShapedType()) &&
113 "only tensor types expected");
114
115 // Create a new transfer_write on buffer that doesn't have a return value.
116 FailureOr<Value> resultBuffer =
117 getBuffer(rewriter, writeOp.getBase(), options, state);
118 if (failed(resultBuffer))
119 return failure();
120 vector::TransferWriteOp::create(
121 rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
122 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
123 writeOp.getMask(), writeOp.getInBoundsAttr());
124 replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
125
126 return success();
127 }
128};
129
130/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
131/// operates on a memref.
132struct ScatterOpInterface
133 : public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
134 vector::ScatterOp> {
135 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
136 const AnalysisState &state) const {
137 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
138 "only tensor types expected");
139 return true;
140 }
141
142 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
143 const AnalysisState &state) const {
144 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
145 "only tensor types expected");
146 return true;
147 }
148
149 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
150 const AnalysisState &state) const {
151 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
152 "only tensor types expected");
153 auto scatterOp = cast<vector::ScatterOp>(op);
154 if (&opOperand != &scatterOp.getBaseMutable())
155 return {};
156 return {{scatterOp.getResult(), BufferRelation::Equivalent}};
157 }
158
159 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
160 const BufferizationOptions &options,
161 BufferizationState &state) const {
162 auto scatterOp = cast<vector::ScatterOp>(op);
163 assert(isa<TensorType>(scatterOp.getBaseType()) &&
164 "only tensor types expected");
165 FailureOr<Value> buffer =
166 getBuffer(rewriter, scatterOp.getBase(), options, state);
167 if (failed(buffer))
168 return failure();
169 vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
170 /*resultType=*/nullptr, *buffer,
171 scatterOp.getOffsets(), scatterOp.getIndices(),
172 scatterOp.getMask(), scatterOp.getValueToStore());
173 replaceOpWithBufferizedValues(rewriter, op, *buffer);
174 return success();
175 }
176};
177
178/// Bufferization of vector.gather. Replaced with a new vector.gather that
179/// operates on a memref.
180struct GatherOpInterface
181 : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
182 vector::GatherOp> {
183 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
184 const AnalysisState &state) const {
185 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
186 "only tensor types expected");
187 return true;
188 }
189
190 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
191 const AnalysisState &state) const {
192 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
193 "only tensor types expected");
194 return false;
195 }
196
197 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
198 const AnalysisState &state) const {
199 return {};
200 }
201
202 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
203 const BufferizationOptions &options,
204 BufferizationState &state) const {
205 auto gatherOp = cast<vector::GatherOp>(op);
206 assert(isa<TensorType>(gatherOp.getBaseType()) &&
207 "only tensor types expected");
208 FailureOr<Value> buffer =
209 getBuffer(rewriter, gatherOp.getBase(), options, state);
210 if (failed(buffer))
211 return failure();
212 replaceOpWithNewBufferizedOp<vector::GatherOp>(
213 rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
214 gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
215 gatherOp.getPassThru());
216 return success();
217 }
218};
219
220/// Bufferization of vector.mask. Replaced with a new vector.mask that
221/// operates on a memref.
222struct MaskOpInterface
223 : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
224 vector::MaskOp> {
225 AliasingOpOperandList
226 getAliasingOpOperands(Operation *op, Value value,
227 const AnalysisState &state) const {
228 // MaskOps do not have tensor OpOperands. The yielded values are the result
229 // of the wrapped op.
230 auto maskOp = cast<vector::MaskOp>(op);
231 size_t resultNum = std::distance(op->getOpResults().begin(),
232 llvm::find(op->getOpResults(), value));
233 auto yieldOp =
234 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
235 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
236 }
237
238 LogicalResult
239 resolveConflicts(Operation *op, RewriterBase &rewriter,
240 const AnalysisState &analysisState,
241 const BufferizationState &bufferizationState) const {
242 auto bufferizableOp = cast<BufferizableOpInterface>(op);
243 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
244 rewriter, analysisState, bufferizationState)))
245 return failure();
246
247 // TODO: Remove this function when vector.mask bodies can bufferize
248 // out-of-place. This is currently not supported because yielding allocs
249 // from a block leads to a memory leak and because vector.mask supports only
250 // a single op in its body.
251 auto maskOp = cast<vector::MaskOp>(op);
252 if (!maskOp.getMaskRegion()
253 .front()
254 .getOps<bufferization::AllocTensorOp>()
255 .empty())
256 return op->emitOpError("body must bufferize in-place");
257
258 return success();
259 }
260
261 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
262 const BufferizationOptions &options,
263 BufferizationState &state) const {
264 auto maskOp = cast<vector::MaskOp>(op);
265
266 // Do not bufferize if the masked op is not bufferizable.
267 Operation *maskedOp = maskOp.getMaskableOp();
268 if (!options.dynCastBufferizableOp(maskedOp))
269 return success();
270
271 // Update the terminator: Drop all operands that are not results of the
272 // masked op.
273 auto yieldOp =
274 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
275 SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
276 SmallVector<Value> newYieldedValues;
277 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
278 if (llvm::is_contained(maskedOp->getOpResults(), it.value())) {
279 newYieldedValues.push_back(it.value());
280 } else {
281 // This used to be a tensor result of the masked op, but is now a memref
282 // that is defined outside of the vector.mask op.
283 newReturnValues[it.index()] = it.value();
284 }
285 }
286 rewriter.modifyOpInPlace(yieldOp, [&]() {
287 yieldOp.getOperandsMutable().assign(newYieldedValues);
288 });
289
290 // Create a new vector.mask op.
291 ValueRange newYieldedValuesRange(newYieldedValues);
292 TypeRange newResultTypes(newYieldedValuesRange);
293 auto newOp = vector::MaskOp::create(
294 rewriter, op->getLoc(), newResultTypes, maskOp.getMask(),
295 maskOp.getPassthru(),
296 /*maskableOp=*/nullptr,
297 /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
298 newOp.getRegion().takeBody(maskOp.getMaskRegion());
299
300 // Replace all uses of the old vector.mask op.
301 int idx = 0;
302 for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
303 if (!newReturnValues[i])
304 newReturnValues[i] = newOp->getResult(idx++);
305 }
306 replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
307 return success();
308 }
309};
310
311/// Bufferization of vector.yield. Replaced with a new vector.yield that
312/// operates on a memref.
313struct YieldOpInterface
314 : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
315 vector::YieldOp> {
316 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
317 const AnalysisState &state) const {
318 return true;
319 }
320
321 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
322 const AnalysisState &state) const {
323 return false;
324 }
325
326 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
327 const AnalysisState &state) const {
328 return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
329 BufferRelation::Equivalent}};
330 }
331
332 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
333 const AnalysisState &state) const {
334 // Yield operands always bufferize inplace. Otherwise, an alloc + copy
335 // may be generated inside the block. We should not return/yield allocations
336 // when possible.
337 return true;
338 }
339
340 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
341 const BufferizationOptions &options,
342 BufferizationState &state) const {
343 auto yieldOp = cast<vector::YieldOp>(op);
344
345 // Only supported as a vector.mask terminator.
346 auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
347 if (!maskOp)
348 return yieldOp->emitError("unsupported vector::YieldOp parent");
349
350 // Do not bufferize if the masked op is not bufferizable.
351 Operation *maskedOp = &maskOp.getMaskRegion().front().front();
352 if (!options.dynCastBufferizableOp(maskedOp))
353 return success();
354
355 // Create a new terminator with the same number of operands. Some of these
356 // may get dropped during the bufferization of vector.mask.
357 SmallVector<Value> newResults;
358 for (Value value : yieldOp.getOperands()) {
359 if (isa<TensorType>(value.getType())) {
360 FailureOr<Value> maybeBuffer =
361 getBuffer(rewriter, value, options, state);
362 if (failed(maybeBuffer))
363 return failure();
364 newResults.push_back(*maybeBuffer);
365 } else {
366 newResults.push_back(value);
367 }
368 }
369
370 replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
371 return success();
372 }
373};
374
375} // namespace
376} // namespace vector
377} // namespace mlir
378
380 DialectRegistry &registry) {
381 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
382 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
383 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
384 GatherOp::attachInterface<GatherOpInterface>(*ctx);
385 MaskOp::attachInterface<MaskOpInterface>(*ctx);
386 YieldOp::attachInterface<YieldOpInterface>(*ctx);
387 ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
388 });
389}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
result_range getOpResults()
Definition Operation.h:420
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.