MLIR  19.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::shape;
21 
22 namespace mlir {
23 namespace shape {
24 namespace {
25 
26 /// Bufferization of shape.assuming.
27 struct AssumingOpInterface
28  : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
29  shape::AssumingOp> {
31  getAliasingOpOperands(Operation *op, Value value,
32  const AnalysisState &state) const {
33  // AssumingOps do not have tensor OpOperands. The yielded value can be any
34  // SSA value that is in scope. To allow for use-def chain traversal through
35  // AssumingOps in the analysis, the corresponding yield value is considered
36  // to be aliasing with the result.
37  auto assumingOp = cast<shape::AssumingOp>(op);
38  size_t resultNum = std::distance(op->getOpResults().begin(),
39  llvm::find(op->getOpResults(), value));
40  // TODO: Support multiple blocks.
41  assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
42  "expected exactly 1 block");
43  auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
44  assumingOp.getDoRegion().front().getTerminator());
45  assert(yieldOp && "expected shape.assuming_yield terminator");
46  return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
47  }
48 
49  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50  const BufferizationOptions &options) const {
51  auto assumingOp = cast<shape::AssumingOp>(op);
52  assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
53  "only 1 block supported");
54  auto yieldOp = cast<shape::AssumingYieldOp>(
55  assumingOp.getDoRegion().front().getTerminator());
56 
57  // Create new op and move over region.
58  TypeRange newResultTypes(yieldOp.getOperands());
59  auto newOp = rewriter.create<shape::AssumingOp>(
60  op->getLoc(), newResultTypes, assumingOp.getWitness());
61  newOp.getDoRegion().takeBody(assumingOp.getRegion());
62 
63  // Update all uses of the old op.
64  rewriter.setInsertionPointAfter(newOp);
65  SmallVector<Value> newResults;
66  for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
67  if (isa<TensorType>(it.value())) {
68  newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
69  assumingOp.getLoc(), newOp->getResult(it.index())));
70  } else {
71  newResults.push_back(newOp->getResult(it.index()));
72  }
73  }
74 
75  // Replace old op.
76  rewriter.replaceOp(assumingOp, newResults);
77 
78  return success();
79  }
80 };
81 
82 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
83 /// ops, so this is for analysis only.
84 struct AssumingYieldOpInterface
85  : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86  shape::AssumingYieldOp> {
87  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88  const AnalysisState &state) const {
89  return true;
90  }
91 
92  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93  const AnalysisState &state) const {
94  return false;
95  }
96 
97  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
98  const AnalysisState &state) const {
99  assert(isa<shape::AssumingOp>(op->getParentOp()) &&
100  "expected that parent is an AssumingOp");
101  OpResult opResult =
102  op->getParentOp()->getResult(opOperand.getOperandNumber());
103  return {{opResult, BufferRelation::Equivalent}};
104  }
105 
106  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
107  const AnalysisState &state) const {
108  // Yield operands always bufferize inplace. Otherwise, an alloc + copy
109  // may be generated inside the block. We should not return/yield allocations
110  // when possible.
111  return true;
112  }
113 
114  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
115  const BufferizationOptions &options) const {
116  auto yieldOp = cast<shape::AssumingYieldOp>(op);
117  SmallVector<Value> newResults;
118  for (Value value : yieldOp.getOperands()) {
119  if (isa<TensorType>(value.getType())) {
120  FailureOr<Value> buffer = getBuffer(rewriter, value, options);
121  if (failed(buffer))
122  return failure();
123  newResults.push_back(*buffer);
124  } else {
125  newResults.push_back(value);
126  }
127  }
128  replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
129  newResults);
130  return success();
131  }
132 };
133 
134 } // namespace
135 } // namespace shape
136 } // namespace mlir
137 
139  DialectRegistry &registry) {
140  registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
141  shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
142  shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
143  });
144 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
result_range getOpResults()
Definition: Operation.h:415
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.