MLIR 22.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/IR/Operation.h"
16
17using namespace mlir;
18using namespace mlir::bufferization;
19using namespace mlir::shape;
20
21namespace mlir {
22namespace shape {
23namespace {
24
25/// Bufferization of shape.assuming.
26struct AssumingOpInterface
27 : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
28 shape::AssumingOp> {
29 AliasingOpOperandList
30 getAliasingOpOperands(Operation *op, Value value,
31 const AnalysisState &state) const {
32 // AssumingOps do not have tensor OpOperands. The yielded value can be any
33 // SSA value that is in scope. To allow for use-def chain traversal through
34 // AssumingOps in the analysis, the corresponding yield value is considered
35 // to be aliasing with the result.
36 auto assumingOp = cast<shape::AssumingOp>(op);
37 size_t resultNum = std::distance(op->getOpResults().begin(),
38 llvm::find(op->getOpResults(), value));
39 // TODO: Support multiple blocks.
40 assert(assumingOp.getDoRegion().hasOneBlock() &&
41 "expected exactly 1 block");
42 auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
43 assumingOp.getDoRegion().front().getTerminator());
44 assert(yieldOp && "expected shape.assuming_yield terminator");
45 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
46 }
47
48 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49 const BufferizationOptions &options,
50 BufferizationState &state) const {
51 auto assumingOp = cast<shape::AssumingOp>(op);
52 assert(assumingOp.getDoRegion().hasOneBlock() && "only 1 block supported");
53 auto yieldOp = cast<shape::AssumingYieldOp>(
54 assumingOp.getDoRegion().front().getTerminator());
55
56 // Create new op and move over region.
57 TypeRange newResultTypes(yieldOp.getOperands());
58 auto newOp = shape::AssumingOp::create(
59 rewriter, op->getLoc(), newResultTypes, assumingOp.getWitness());
60 newOp.getDoRegion().takeBody(assumingOp.getRegion());
61
62 // Update all uses of the old op.
63 rewriter.setInsertionPointAfter(newOp);
64 SmallVector<Value> newResults;
65 for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
66 if (isa<TensorType>(it.value())) {
67 newResults.push_back(bufferization::ToTensorOp::create(
68 rewriter, assumingOp.getLoc(), it.value(),
69 newOp->getResult(it.index())));
70 } else {
71 newResults.push_back(newOp->getResult(it.index()));
72 }
73 }
74
75 // Replace old op.
76 rewriter.replaceOp(assumingOp, newResults);
77
78 return success();
79 }
80};
81
82/// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
83/// ops, so this is for analysis only.
84struct AssumingYieldOpInterface
85 : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86 shape::AssumingYieldOp> {
87 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88 const AnalysisState &state) const {
89 return true;
90 }
91
92 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93 const AnalysisState &state) const {
94 return false;
95 }
96
97 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
98 const AnalysisState &state) const {
99 assert(isa<shape::AssumingOp>(op->getParentOp()) &&
100 "expected that parent is an AssumingOp");
101 OpResult opResult =
102 op->getParentOp()->getResult(opOperand.getOperandNumber());
103 return {{opResult, BufferRelation::Equivalent}};
104 }
105
106 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
107 const AnalysisState &state) const {
108 // Yield operands always bufferize inplace. Otherwise, an alloc + copy
109 // may be generated inside the block. We should not return/yield allocations
110 // when possible.
111 return true;
112 }
113
114 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
115 const BufferizationOptions &options,
116 BufferizationState &state) const {
117 auto yieldOp = cast<shape::AssumingYieldOp>(op);
118 SmallVector<Value> newResults;
119 for (Value value : yieldOp.getOperands()) {
120 if (isa<TensorType>(value.getType())) {
121 FailureOr<Value> buffer = getBuffer(rewriter, value, options, state);
122 if (failed(buffer))
123 return failure();
124 newResults.push_back(*buffer);
125 } else {
126 newResults.push_back(value);
127 }
128 }
129 replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
130 newResults);
131 return success();
132 }
133};
134
135} // namespace
136} // namespace shape
137} // namespace mlir
138
140 DialectRegistry &registry) {
141 registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
142 shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
143 shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
144 });
145}
return success()
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:412
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:234
result_range getOpResults()
Definition Operation.h:420
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.