28 struct TransferReadOpInterface
29 :
public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
30 vector::TransferReadOp> {
34 "only tensor types expected");
41 "only tensor types expected");
52 auto readOp = cast<vector::TransferReadOp>(op);
53 assert(readOp.getShapedType().isa<
TensorType>() &&
54 "only tensor types expected");
58 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
59 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
60 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
61 readOp.getInBoundsAttr());
71 struct TransferWriteOpInterface
73 vector::TransferWriteOp> {
76 auto writeOp = cast<vector::TransferWriteOp>(op);
77 assert(writeOp.getShapedType().isa<
TensorType>() &&
78 "only tensor types expected");
85 rewriter.
create<vector::TransferWriteOp>(
86 writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
87 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
88 writeOp.getMask(), writeOp.getInBoundsAttr());
97 struct GatherOpInterface
98 :
public BufferizableOpInterface::ExternalModel<GatherOpInterface,
103 "only tensor types expected");
110 "only tensor types expected");
121 auto gatherOp = cast<vector::GatherOp>(op);
122 assert(gatherOp.getBaseType().isa<
TensorType>() &&
123 "only tensor types expected");
127 replaceOpWithNewBufferizedOp<vector::GatherOp>(
128 rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
129 gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
130 gatherOp.getPassThru());
137 struct MaskOpInterface
138 :
public BufferizableOpInterface::ExternalModel<MaskOpInterface,
145 auto maskOp = cast<vector::MaskOp>(op);
146 size_t resultNum = std::distance(op->
getOpResults().begin(),
149 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
155 auto bufferizableOp = cast<BufferizableOpInterface>(op);
156 if (
failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
163 auto maskOp = cast<vector::MaskOp>(op);
164 if (!maskOp.getMaskRegion()
166 .getOps<bufferization::AllocTensorOp>()
168 return op->
emitOpError(
"body must bufferize in-place");
175 auto maskOp = cast<vector::MaskOp>(op);
178 Operation *maskedOp = maskOp.getMaskableOp();
179 if (!
options.dynCastBufferizableOp(maskedOp))
185 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
191 newYieldedValues.push_back(it.value());
195 newReturnValues[it.index()] = it.value();
199 yieldOp.getOperandsMutable().assign(newYieldedValues);
203 ValueRange newYieldedValuesRange(newYieldedValues);
204 TypeRange newResultTypes(newYieldedValuesRange);
205 auto newOp = rewriter.
create<vector::MaskOp>(
206 op->
getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
213 for (
int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
214 if (!newReturnValues[i])
215 newReturnValues[i] = newOp->getResult(idx++);
224 struct YieldOpInterface
225 :
public BufferizableOpInterface::ExternalModel<YieldOpInterface,
253 auto yieldOp = cast<vector::YieldOp>(op);
256 auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
258 return yieldOp->emitError(
"unsupported vector::YieldOp parent");
261 Operation *maskedOp = &maskOp.getMaskRegion().front().front();
262 if (!
options.dynCastBufferizableOp(maskedOp))
268 for (
Value value : yieldOp.getOperands()) {
273 newResults.push_back(*maybeBuffer);
275 newResults.push_back(value);
279 replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
291 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
292 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
293 GatherOp::attachInterface<GatherOpInterface>(*ctx);
294 MaskOp::attachInterface<MaskOpInterface>(*ctx);
295 YieldOp::attachInterface<YieldOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getOpResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
Bufferizable ops that implement the DestinationStyleOpInterface can use this external model base clas...