26struct AssumingOpInterface
27 :
public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
30 getAliasingOpOperands(Operation *op, Value value,
31 const AnalysisState &state)
const {
36 auto assumingOp = cast<shape::AssumingOp>(op);
37 size_t resultNum = std::distance(op->
getOpResults().begin(),
40 assert(assumingOp.getDoRegion().hasOneBlock() &&
41 "expected exactly 1 block");
42 auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
43 assumingOp.getDoRegion().front().getTerminator());
44 assert(yieldOp &&
"expected shape.assuming_yield terminator");
45 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
48 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49 const BufferizationOptions &
options,
50 BufferizationState &state)
const {
51 auto assumingOp = cast<shape::AssumingOp>(op);
52 assert(assumingOp.getDoRegion().hasOneBlock() &&
"only 1 block supported");
53 auto yieldOp = cast<shape::AssumingYieldOp>(
54 assumingOp.getDoRegion().front().getTerminator());
57 TypeRange newResultTypes(yieldOp.getOperands());
58 auto newOp = shape::AssumingOp::create(
59 rewriter, op->
getLoc(), newResultTypes, assumingOp.getWitness());
60 newOp.getDoRegion().takeBody(assumingOp.getRegion());
64 SmallVector<Value> newResults;
65 for (
const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
66 if (isa<TensorType>(it.value())) {
67 newResults.push_back(bufferization::ToTensorOp::create(
68 rewriter, assumingOp.getLoc(), it.value(),
69 newOp->getResult(it.index())));
71 newResults.push_back(newOp->getResult(it.index()));
76 rewriter.
replaceOp(assumingOp, newResults);
84struct AssumingYieldOpInterface
85 :
public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86 shape::AssumingYieldOp> {
87 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88 const AnalysisState &state)
const {
92 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93 const AnalysisState &state)
const {
97 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
98 const AnalysisState &state)
const {
100 "expected that parent is an AssumingOp");
103 return {{opResult, BufferRelation::Equivalent}};
106 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
107 const AnalysisState &state)
const {
114 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
115 const BufferizationOptions &
options,
116 BufferizationState &state)
const {
117 auto yieldOp = cast<shape::AssumingYieldOp>(op);
118 SmallVector<Value> newResults;
119 for (Value value : yieldOp.getOperands()) {
120 if (isa<TensorType>(value.
getType())) {
121 FailureOr<Value> buffer = getBuffer(rewriter, value,
options, state);
124 newResults.push_back(*buffer);
126 newResults.push_back(value);
129 replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
142 shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
143 shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
result_range getOpResults()
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Type getType() const
Return the type of this value.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.