MLIR  22.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 using namespace mlir::shape;
20 
21 namespace mlir {
22 namespace shape {
23 namespace {
24 
25 /// Bufferization of shape.assuming.
26 struct AssumingOpInterface
27  : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
28  shape::AssumingOp> {
30  getAliasingOpOperands(Operation *op, Value value,
31  const AnalysisState &state) const {
32  // AssumingOps do not have tensor OpOperands. The yielded value can be any
33  // SSA value that is in scope. To allow for use-def chain traversal through
34  // AssumingOps in the analysis, the corresponding yield value is considered
35  // to be aliasing with the result.
36  auto assumingOp = cast<shape::AssumingOp>(op);
37  size_t resultNum = std::distance(op->getOpResults().begin(),
38  llvm::find(op->getOpResults(), value));
39  // TODO: Support multiple blocks.
40  assert(assumingOp.getDoRegion().hasOneBlock() &&
41  "expected exactly 1 block");
42  auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
43  assumingOp.getDoRegion().front().getTerminator());
44  assert(yieldOp && "expected shape.assuming_yield terminator");
45  return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
46  }
47 
48  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50  BufferizationState &state) const {
51  auto assumingOp = cast<shape::AssumingOp>(op);
52  assert(assumingOp.getDoRegion().hasOneBlock() && "only 1 block supported");
53  auto yieldOp = cast<shape::AssumingYieldOp>(
54  assumingOp.getDoRegion().front().getTerminator());
55 
56  // Create new op and move over region.
57  TypeRange newResultTypes(yieldOp.getOperands());
58  auto newOp = shape::AssumingOp::create(
59  rewriter, op->getLoc(), newResultTypes, assumingOp.getWitness());
60  newOp.getDoRegion().takeBody(assumingOp.getRegion());
61 
62  // Update all uses of the old op.
63  rewriter.setInsertionPointAfter(newOp);
64  SmallVector<Value> newResults;
65  for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
66  if (isa<TensorType>(it.value())) {
67  newResults.push_back(bufferization::ToTensorOp::create(
68  rewriter, assumingOp.getLoc(), it.value(),
69  newOp->getResult(it.index())));
70  } else {
71  newResults.push_back(newOp->getResult(it.index()));
72  }
73  }
74 
75  // Replace old op.
76  rewriter.replaceOp(assumingOp, newResults);
77 
78  return success();
79  }
80 };
81 
82 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
83 /// ops, so this is for analysis only.
84 struct AssumingYieldOpInterface
85  : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86  shape::AssumingYieldOp> {
87  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88  const AnalysisState &state) const {
89  return true;
90  }
91 
92  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93  const AnalysisState &state) const {
94  return false;
95  }
96 
97  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
98  const AnalysisState &state) const {
99  assert(isa<shape::AssumingOp>(op->getParentOp()) &&
100  "expected that parent is an AssumingOp");
101  OpResult opResult =
102  op->getParentOp()->getResult(opOperand.getOperandNumber());
103  return {{opResult, BufferRelation::Equivalent}};
104  }
105 
106  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
107  const AnalysisState &state) const {
108  // Yield operands always bufferize inplace. Otherwise, an alloc + copy
109  // may be generated inside the block. We should not return/yield allocations
110  // when possible.
111  return true;
112  }
113 
114  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
116  BufferizationState &state) const {
117  auto yieldOp = cast<shape::AssumingYieldOp>(op);
118  SmallVector<Value> newResults;
119  for (Value value : yieldOp.getOperands()) {
120  if (isa<TensorType>(value.getType())) {
121  FailureOr<Value> buffer = getBuffer(rewriter, value, options, state);
122  if (failed(buffer))
123  return failure();
124  newResults.push_back(*buffer);
125  } else {
126  newResults.push_back(value);
127  }
128  }
129  replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
130  newResults);
131  return success();
132  }
133 };
134 
135 } // namespace
136 } // namespace shape
137 } // namespace mlir
138 
140  DialectRegistry &registry) {
141  registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
142  shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
143  shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
144  });
145 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
result_range getOpResults()
Definition: Operation.h:420
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The original operation is erased.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
BufferizationState provides information about the state of the IR during the bufferization process.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Options for BufferizableOpInterface-based bufferization.