29struct TransferReadOpInterface
30 :
public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
31 vector::TransferReadOp> {
32 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
33 const AnalysisState &state)
const {
34 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
35 "only tensor types expected");
39 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
40 const AnalysisState &state)
const {
41 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
42 "only tensor types expected");
46 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
47 const AnalysisState &state)
const {
51 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
52 const BufferizationOptions &
options,
53 BufferizationState &state)
const {
54 auto readOp = cast<vector::TransferReadOp>(op);
55 assert(isa<TensorType>(readOp.getShapedType()) &&
56 "only tensor types expected");
57 FailureOr<Value> buffer =
58 getBuffer(rewriter, readOp.getBase(),
options, state);
61 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
62 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
63 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
64 readOp.getInBoundsAttr());
74struct TransferWriteOpInterface
75 :
public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
76 vector::TransferWriteOp> {
77 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
78 const AnalysisState &state)
const {
79 auto writeOp = cast<vector::TransferWriteOp>(op);
85 if (!writeOp.getShapedType().hasStaticShape())
89 for (Value offset : writeOp.getIndices()) {
95 if (writeOp.isMasked())
99 for (
auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
100 writeOp.getVectorType().getShape())) {
108 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
109 const BufferizationOptions &
options,
110 BufferizationState &state)
const {
111 auto writeOp = cast<vector::TransferWriteOp>(op);
112 assert(isa<TensorType>(writeOp.getShapedType()) &&
113 "only tensor types expected");
116 FailureOr<Value> resultBuffer =
117 getBuffer(rewriter, writeOp.getBase(),
options, state);
120 vector::TransferWriteOp::create(
121 rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
122 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
123 writeOp.getMask(), writeOp.getInBoundsAttr());
124 replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
132struct ScatterOpInterface
133 :
public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
135 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
136 const AnalysisState &state)
const {
137 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
138 "only tensor types expected");
142 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
143 const AnalysisState &state)
const {
144 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
145 "only tensor types expected");
149 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
150 const AnalysisState &state)
const {
151 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
152 "only tensor types expected");
153 auto scatterOp = cast<vector::ScatterOp>(op);
154 if (&opOperand != &scatterOp.getBaseMutable())
156 return {{scatterOp.getResult(), BufferRelation::Equivalent}};
159 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
160 const BufferizationOptions &
options,
161 BufferizationState &state)
const {
162 auto scatterOp = cast<vector::ScatterOp>(op);
163 assert(isa<TensorType>(scatterOp.getBaseType()) &&
164 "only tensor types expected");
165 FailureOr<Value> buffer =
166 getBuffer(rewriter, scatterOp.getBase(),
options, state);
169 vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
171 scatterOp.getOffsets(), scatterOp.getIndices(),
172 scatterOp.getMask(), scatterOp.getValueToStore());
173 replaceOpWithBufferizedValues(rewriter, op, *buffer);
180struct GatherOpInterface
181 :
public BufferizableOpInterface::ExternalModel<GatherOpInterface,
183 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
184 const AnalysisState &state)
const {
185 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
186 "only tensor types expected");
190 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
191 const AnalysisState &state)
const {
192 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
193 "only tensor types expected");
197 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
198 const AnalysisState &state)
const {
202 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
203 const BufferizationOptions &
options,
204 BufferizationState &state)
const {
205 auto gatherOp = cast<vector::GatherOp>(op);
206 assert(isa<TensorType>(gatherOp.getBaseType()) &&
207 "only tensor types expected");
208 FailureOr<Value> buffer =
209 getBuffer(rewriter, gatherOp.getBase(),
options, state);
212 replaceOpWithNewBufferizedOp<vector::GatherOp>(
213 rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
214 gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
215 gatherOp.getPassThru());
222struct MaskOpInterface
223 :
public BufferizableOpInterface::ExternalModel<MaskOpInterface,
225 AliasingOpOperandList
226 getAliasingOpOperands(Operation *op, Value value,
227 const AnalysisState &state)
const {
230 auto maskOp = cast<vector::MaskOp>(op);
231 size_t resultNum = std::distance(op->
getOpResults().begin(),
234 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
235 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
239 resolveConflicts(Operation *op, RewriterBase &rewriter,
240 const AnalysisState &analysisState,
241 const BufferizationState &bufferizationState)
const {
242 auto bufferizableOp = cast<BufferizableOpInterface>(op);
243 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(
244 rewriter, analysisState, bufferizationState)))
251 auto maskOp = cast<vector::MaskOp>(op);
252 if (!maskOp.getMaskRegion()
254 .getOps<bufferization::AllocTensorOp>()
256 return op->
emitOpError(
"body must bufferize in-place");
261 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
262 const BufferizationOptions &
options,
263 BufferizationState &state)
const {
264 auto maskOp = cast<vector::MaskOp>(op);
267 Operation *maskedOp = maskOp.getMaskableOp();
268 if (!
options.dynCastBufferizableOp(maskedOp))
274 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
275 SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
276 SmallVector<Value> newYieldedValues;
277 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
278 if (llvm::is_contained(maskedOp->
getOpResults(), it.value())) {
279 newYieldedValues.push_back(it.value());
283 newReturnValues[it.index()] = it.value();
287 yieldOp.getOperandsMutable().assign(newYieldedValues);
291 ValueRange newYieldedValuesRange(newYieldedValues);
292 TypeRange newResultTypes(newYieldedValuesRange);
293 auto newOp = vector::MaskOp::create(
294 rewriter, op->
getLoc(), newResultTypes, maskOp.getMask(),
295 maskOp.getPassthru(),
297 [](OpBuilder &
b, Operation *) {});
298 newOp.getRegion().takeBody(maskOp.getMaskRegion());
302 for (
int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
303 if (!newReturnValues[i])
304 newReturnValues[i] = newOp->getResult(idx++);
306 replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
313struct YieldOpInterface
314 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
316 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
317 const AnalysisState &state)
const {
321 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
322 const AnalysisState &state)
const {
326 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
327 const AnalysisState &state)
const {
329 BufferRelation::Equivalent}};
332 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
333 const AnalysisState &state)
const {
340 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
341 const BufferizationOptions &
options,
342 BufferizationState &state)
const {
343 auto yieldOp = cast<vector::YieldOp>(op);
346 auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
348 return yieldOp->emitError(
"unsupported vector::YieldOp parent");
351 Operation *maskedOp = &maskOp.getMaskRegion().front().front();
352 if (!
options.dynCastBufferizableOp(maskedOp))
357 SmallVector<Value> newResults;
358 for (Value value : yieldOp.getOperands()) {
359 if (isa<TensorType>(value.
getType())) {
360 FailureOr<Value> maybeBuffer =
361 getBuffer(rewriter, value,
options, state);
364 newResults.push_back(*maybeBuffer);
366 newResults.push_back(value);
370 replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
382 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
383 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
384 GatherOp::attachInterface<GatherOpInterface>(*ctx);
385 MaskOp::attachInterface<MaskOpInterface>(*ctx);
386 YieldOp::attachInterface<YieldOpInterface>(*ctx);
387 ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Type getType() const
Return the type of this value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.