28 struct TransferReadOpInterface
29 :
public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
30 vector::TransferReadOp> {
33 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
34 "only tensor types expected");
40 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
41 "only tensor types expected");
53 auto readOp = cast<vector::TransferReadOp>(op);
54 assert(isa<TensorType>(readOp.getShapedType()) &&
55 "only tensor types expected");
56 FailureOr<Value> buffer =
60 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
61 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
62 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
63 readOp.getInBoundsAttr());
73 struct TransferWriteOpInterface
75 vector::TransferWriteOp> {
78 auto writeOp = cast<vector::TransferWriteOp>(op);
84 if (!writeOp.getShapedType().hasStaticShape())
88 for (
Value offset : writeOp.getIndices()) {
94 if (writeOp.isMasked())
98 for (
auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
99 writeOp.getVectorType().getShape())) {
110 auto writeOp = cast<vector::TransferWriteOp>(op);
111 assert(isa<TensorType>(writeOp.getShapedType()) &&
112 "only tensor types expected");
115 FailureOr<Value> resultBuffer =
117 if (failed(resultBuffer))
119 rewriter.
create<vector::TransferWriteOp>(
120 writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
121 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
122 writeOp.getMask(), writeOp.getInBoundsAttr());
131 struct GatherOpInterface
132 :
public BufferizableOpInterface::ExternalModel<GatherOpInterface,
136 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
137 "only tensor types expected");
143 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
144 "only tensor types expected");
156 auto gatherOp = cast<vector::GatherOp>(op);
157 assert(isa<TensorType>(gatherOp.getBaseType()) &&
158 "only tensor types expected");
159 FailureOr<Value> buffer =
163 replaceOpWithNewBufferizedOp<vector::GatherOp>(
164 rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
165 gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
166 gatherOp.getPassThru());
173 struct MaskOpInterface
174 :
public BufferizableOpInterface::ExternalModel<MaskOpInterface,
181 auto maskOp = cast<vector::MaskOp>(op);
182 size_t resultNum = std::distance(op->
getOpResults().begin(),
185 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
193 auto bufferizableOp = cast<BufferizableOpInterface>(op);
194 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
195 rewriter, analysisState, bufferizationState)))
202 auto maskOp = cast<vector::MaskOp>(op);
203 if (!maskOp.getMaskRegion()
205 .getOps<bufferization::AllocTensorOp>()
207 return op->
emitOpError(
"body must bufferize in-place");
215 auto maskOp = cast<vector::MaskOp>(op);
218 Operation *maskedOp = maskOp.getMaskableOp();
219 if (!
options.dynCastBufferizableOp(maskedOp))
225 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
229 if (llvm::is_contained(maskedOp->
getOpResults(), it.value())) {
230 newYieldedValues.push_back(it.value());
234 newReturnValues[it.index()] = it.value();
238 yieldOp.getOperandsMutable().assign(newYieldedValues);
242 ValueRange newYieldedValuesRange(newYieldedValues);
243 TypeRange newResultTypes(newYieldedValuesRange);
244 auto newOp = rewriter.
create<vector::MaskOp>(
245 op->
getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
252 for (
int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
253 if (!newReturnValues[i])
254 newReturnValues[i] = newOp->getResult(idx++);
263 struct YieldOpInterface
264 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
293 auto yieldOp = cast<vector::YieldOp>(op);
296 auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
298 return yieldOp->emitError(
"unsupported vector::YieldOp parent");
301 Operation *maskedOp = &maskOp.getMaskRegion().front().front();
302 if (!
options.dynCastBufferizableOp(maskedOp))
308 for (
Value value : yieldOp.getOperands()) {
309 if (isa<TensorType>(value.
getType())) {
310 FailureOr<Value> maybeBuffer =
312 if (failed(maybeBuffer))
314 newResults.push_back(*maybeBuffer);
316 newResults.push_back(value);
320 replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
332 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
333 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
334 GatherOp::attachInterface<GatherOpInterface>(*ctx);
335 MaskOp::attachInterface<MaskOpInterface>(*ctx);
336 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
BufferizationState provides information about the state of the IR during the bufferization process.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.