// Bufferization external model for vector::TransferReadOp.
// NOTE(review): this excerpt is missing lines (method signatures, return
// statements, closing braces); the comments below describe only what is
// visible here.
28 struct TransferReadOpInterface
29 :
public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
30 vector::TransferReadOp> {
// Operand-type guard: the queried operand must be a ranked tensor. The
// enclosing method's signature is not shown (presumably one of the memory
// read/write queries -- TODO confirm).
33 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
34 "only tensor types expected");
// Same operand-type guard in a second method whose signature is not shown.
40 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
41 "only tensor types expected");
// Rewrite step: the source of the read must be tensor-typed.
// NOTE(review): this checks TensorType, while the operand guards above check
// RankedTensorType.
52 auto readOp = cast<vector::TransferReadOp>(op);
53 assert(isa<TensorType>(readOp.getShapedType()) &&
54 "only tensor types expected");
// Replace the op with a TransferReadOp that reads from `*buffer`. `buffer`
// is defined on lines not shown; presumably it is the source's buffer. The
// new op keeps the same vector type, indices, permutation map, padding,
// mask and in_bounds attribute.
58 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
59 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
60 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
61 readOp.getInBoundsAttr());
// Bufferization external model for vector::TransferWriteOp.
// NOTE(review): this excerpt is missing lines, including the base-class line,
// method signatures and return statements; comments describe only what is
// visible.
71 struct TransferWriteOpInterface
73 vector::TransferWriteOp> {
// Predicate-style method; its name and return values are on lines not shown.
// It inspects, in order:
//   - whether the destination has a static shape;
//   - each index (the per-index test is not shown);
//   - whether the write is masked;
//   - the destination and vector dimension sizes, pairwise.
// Presumably it checks that the write covers the whole destination
// elementwise -- TODO confirm.
76 auto writeOp = cast<vector::TransferWriteOp>(op);
82 if (!writeOp.getShapedType().hasStaticShape())
86 for (
Value offset : writeOp.getIndices()) {
92 if (writeOp.isMasked())
96 for (
auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
97 writeOp.getVectorType().getShape())) {
// Rewrite step. The destination must be tensor-typed.
107 auto writeOp = cast<vector::TransferWriteOp>(op);
108 assert(isa<TensorType>(writeOp.getShapedType()) &&
109 "only tensor types expected");
// The buffer is obtained on a line not shown; the failure path is also not
// shown.
112 FailureOr<Value> resultBuffer =
114 if (failed(resultBuffer))
// Create a TransferWriteOp into that buffer. It forwards the same vector,
// indices, permutation map, mask and in_bounds attribute. The created op is
// not captured here.
116 rewriter.
create<vector::TransferWriteOp>(
117 writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
118 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
119 writeOp.getMask(), writeOp.getInBoundsAttr());
// Bufferization external model for the gather op. The rewrite below casts to
// vector::GatherOp; the template-argument line naming the op is not shown.
// NOTE(review): this excerpt is missing lines; comments describe only what is
// visible.
128 struct GatherOpInterface
129 :
public BufferizableOpInterface::ExternalModel<GatherOpInterface,
// Operand-type guard: the queried operand must be a ranked tensor. The
// enclosing method signature is not shown.
133 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
134 "only tensor types expected");
// Same operand-type guard in a second method whose signature is not shown.
140 assert(isa<RankedTensorType>(opOperand.
get().
getType()) &&
141 "only tensor types expected");
// Rewrite step: the gather base must be tensor-typed.
// NOTE(review): this checks TensorType, while the guards above check
// RankedTensorType.
152 auto gatherOp = cast<vector::GatherOp>(op);
153 assert(isa<TensorType>(gatherOp.getBaseType()) &&
154 "only tensor types expected");
// Replace the op with a GatherOp on `*buffer`. `buffer` is defined on lines
// not shown; presumably it is the base's buffer. The new op keeps the same
// vector type, indices, index vector, mask and pass-through value.
158 replaceOpWithNewBufferizedOp<vector::GatherOp>(
159 rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
160 gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
161 gatherOp.getPassThru());
// Bufferization external model for vector::MaskOp.
// NOTE(review): this excerpt is missing lines (method signatures, loop
// headers, declarations of `newYieldedValues`, `newReturnValues`, `idx`, and
// return statements); comments describe only what is visible.
168 struct MaskOpInterface
169 :
public BufferizableOpInterface::ExternalModel<MaskOpInterface,
// Computes the index of some op result (`resultNum`) and fetches the mask
// region's yield terminator. Presumably this maps a MaskOp result to its
// yielded value -- TODO confirm.
176 auto maskOp = cast<vector::MaskOp>(op);
177 size_t resultNum = std::distance(op->
getOpResults().begin(),
180 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
// Resolve tensor operand conflicts via the generic BufferizableOpInterface
// helper. The failure path is on a line not shown.
186 auto bufferizableOp = cast<BufferizableOpInterface>(op);
187 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
// The body must bufferize in place. The tail of this condition is not shown;
// presumably it errors when the mask region contains
// bufferization.alloc_tensor ops.
194 auto maskOp = cast<vector::MaskOp>(op);
195 if (!maskOp.getMaskRegion()
197 .getOps<bufferization::AllocTensorOp>()
199 return op->
emitOpError(
"body must bufferize in-place");
// Rewrite step. The masked op must itself be bufferizable; the handling for
// a non-bufferizable op is on a line not shown.
206 auto maskOp = cast<vector::MaskOp>(op);
209 Operation *maskedOp = maskOp.getMaskableOp();
210 if (!
options.dynCastBufferizableOp(maskedOp))
// Split the yielded values (loop header not shown):
//   - values produced by the masked op stay yielded;
//   - other values become direct replacements for the corresponding
//     MaskOp result.
216 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
220 if (llvm::is_contained(maskedOp->
getOpResults(), it.value())) {
221 newYieldedValues.push_back(it.value());
225 newReturnValues[it.index()] = it.value();
// Shrink the yield to the values that stay yielded.
229 yieldOp.getOperandsMutable().assign(newYieldedValues);
// Build a new MaskOp whose result types match the remaining yielded values.
// It reuses the original mask and passthru; further arguments are on lines
// not shown.
233 ValueRange newYieldedValuesRange(newYieldedValues);
234 TypeRange newResultTypes(newYieldedValuesRange);
235 auto newOp = rewriter.
create<vector::MaskOp>(
236 op->
getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
// Fill the still-empty replacement slots with the new MaskOp's results, in
// order. `idx` is declared on a line not shown.
// NOTE(review): the cast of the unsigned result count to int is avoidable;
// an unsigned loop index would do.
243 for (
int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
244 if (!newReturnValues[i])
245 newReturnValues[i] = newOp->getResult(idx++);
// Bufferization external model for vector::YieldOp.
// NOTE(review): this excerpt is missing lines (method signatures, the
// parent-check condition, the `newResults`/`maybeBuffer` declarations, and
// failure returns); comments describe only what is visible.
254 struct YieldOpInterface
255 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
// Rewrite step: only a vector::MaskOp parent is handled here.
// NOTE(review): this uses emitError, while MaskOpInterface uses emitOpError.
283 auto yieldOp = cast<vector::YieldOp>(op);
286 auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
288 return yieldOp->emitError(
"unsupported vector::YieldOp parent");
// The masked op is taken as the first op of the mask region's first block.
// NOTE(review): MaskOpInterface uses maskOp.getMaskableOp() for the same
// purpose.
291 Operation *maskedOp = &maskOp.getMaskRegion().front().front();
292 if (!
options.dynCastBufferizableOp(maskedOp))
// Build the new operand list:
//   - a tensor operand is replaced by its buffer, which is obtained on a
//     line not shown; the failure path is also not shown;
//   - any other operand is passed through unchanged.
298 for (
Value value : yieldOp.getOperands()) {
299 if (isa<TensorType>(value.
getType())) {
301 if (failed(maybeBuffer))
303 newResults.push_back(*maybeBuffer);
305 newResults.push_back(value);
// Replace the yield with one that yields the bufferized operands.
309 replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
// Attach each bufferization external model to its vector op in context
// `ctx`. Presumably this runs inside the dialect-extension callback
// registered by registerBufferizableOpInterfaceExternalModels; the enclosing
// lines are not shown.
321 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
322 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
323 GatherOp::attachInterface<GatherOpInterface>(*ctx);
324 MaskOp::attachInterface<MaskOpInterface>(*ctx);
325 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base class.