27 struct AssumingOpInterface
28 :
public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
37 auto assumingOp = cast<shape::AssumingOp>(op);
38 size_t resultNum = std::distance(op->
getOpResults().begin(),
41 assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
42 "expected exactly 1 block");
43 auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
44 assumingOp.getDoRegion().front().getTerminator());
45 assert(yieldOp &&
"expected shape.assuming_yield terminator");
51 auto assumingOp = cast<shape::AssumingOp>(op);
52 assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
53 "only 1 block supported");
54 auto yieldOp = cast<shape::AssumingYieldOp>(
55 assumingOp.getDoRegion().front().getTerminator());
58 TypeRange newResultTypes(yieldOp.getOperands());
59 auto newOp = rewriter.
create<shape::AssumingOp>(
60 op->
getLoc(), newResultTypes, assumingOp.getWitness());
61 newOp.getDoRegion().takeBody(assumingOp.getRegion());
67 if (isa<TensorType>(it.value())) {
68 newResults.push_back(rewriter.
create<bufferization::ToTensorOp>(
69 assumingOp.getLoc(), newOp->getResult(it.index())));
71 newResults.push_back(newOp->getResult(it.index()));
76 rewriter.
replaceOp(assumingOp, newResults);
84 struct AssumingYieldOpInterface
85 :
public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86 shape::AssumingYieldOp> {
100 "expected that parent is an AssumingOp");
116 auto yieldOp = cast<shape::AssumingYieldOp>(op);
118 for (
Value value : yieldOp.getOperands()) {
119 if (isa<TensorType>(value.
getType())) {
123 newResults.push_back(*buffer);
125 newResults.push_back(value);
128 replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
141 shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
142 shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
result_range getOpResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the various different ranges of value types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Options for BufferizableOpInterface-based bufferization.