28 struct TransferReadOpInterface
29 :
public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
30 vector::TransferReadOp> {
33 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
34 "only tensor types expected");
40 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
41 "only tensor types expected");
52 auto readOp = cast<vector::TransferReadOp>(op);
53 assert(isa<TensorType>(readOp.getShapedType()) &&
54 "only tensor types expected");
58 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
59 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
60 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
61 readOp.getInBoundsAttr());
71 struct TransferWriteOpInterface
73 vector::TransferWriteOp> {
76 auto writeOp = cast<vector::TransferWriteOp>(op);
82 if (!writeOp.getShapedType().hasStaticShape())
86 for (
Value offset : writeOp.getIndices()) {
92 if (writeOp.isMasked())
96 for (
auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
97 writeOp.getVectorType().getShape())) {
107 auto writeOp = cast<vector::TransferWriteOp>(op);
108 assert(isa<TensorType>(writeOp.getShapedType()) &&
109 "only tensor types expected");
116 rewriter.
create<vector::TransferWriteOp>(
117 writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
118 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
119 writeOp.getMask(), writeOp.getInBoundsAttr());
128 struct GatherOpInterface
129 :
public BufferizableOpInterface::ExternalModel<GatherOpInterface,
133 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
134 "only tensor types expected");
140 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
141 "only tensor types expected");
152 auto gatherOp = cast<vector::GatherOp>(op);
153 assert(isa<TensorType>(gatherOp.getBaseType()) &&
154 "only tensor types expected");
158 replaceOpWithNewBufferizedOp<vector::GatherOp>(
159 rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
160 gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
161 gatherOp.getPassThru());
168 struct MaskOpInterface
169 :
public BufferizableOpInterface::ExternalModel<MaskOpInterface,
176 auto maskOp = cast<vector::MaskOp>(op);
177 size_t resultNum = std::distance(op->
getOpResults().begin(),
180 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
186 auto bufferizableOp = cast<BufferizableOpInterface>(op);
187 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
194 auto maskOp = cast<vector::MaskOp>(op);
195 if (!maskOp.getMaskRegion()
197 .getOps<bufferization::AllocTensorOp>()
199 return op->
emitOpError(
"body must bufferize in-place");
206 auto maskOp = cast<vector::MaskOp>(op);
209 Operation *maskedOp = maskOp.getMaskableOp();
210 if (!
options.dynCastBufferizableOp(maskedOp))
216 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
222 newYieldedValues.push_back(it.value());
226 newReturnValues[it.index()] = it.value();
230 yieldOp.getOperandsMutable().assign(newYieldedValues);
234 ValueRange newYieldedValuesRange(newYieldedValues);
235 TypeRange newResultTypes(newYieldedValuesRange);
236 auto newOp = rewriter.
create<vector::MaskOp>(
237 op->
getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
244 for (
int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
245 if (!newReturnValues[i])
246 newReturnValues[i] = newOp->getResult(idx++);
255 struct YieldOpInterface
256 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
284 auto yieldOp = cast<vector::YieldOp>(op);
287 auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
289 return yieldOp->emitError(
"unsupported vector::YieldOp parent");
292 Operation *maskedOp = &maskOp.getMaskRegion().front().front();
293 if (!
options.dynCastBufferizableOp(maskedOp))
299 for (
Value value : yieldOp.getOperands()) {
300 if (isa<TensorType>(value.
getType())) {
304 newResults.push_back(*maybeBuffer);
306 newResults.push_back(value);
310 replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
322 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
323 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
324 GatherOp::attachInterface<GatherOpInterface>(*ctx);
325 MaskOp::attachInterface<MaskOpInterface>(*ctx);
326 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.