MLIR  18.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::shape;
21 
22 namespace mlir {
23 namespace shape {
24 namespace {
25 
26 /// Bufferization of shape.assuming.
27 struct AssumingOpInterface
28  : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
29  shape::AssumingOp> {
31  getAliasingOpOperands(Operation *op, Value value,
32  const AnalysisState &state) const {
33  // AssumingOps do not have tensor OpOperands. The yielded value can be any
34  // SSA value that is in scope. To allow for use-def chain traversal through
35  // AssumingOps in the analysis, the corresponding yield value is considered
36  // to be aliasing with the result.
37  auto assumingOp = cast<shape::AssumingOp>(op);
38  size_t resultNum = std::distance(op->getOpResults().begin(),
39  llvm::find(op->getOpResults(), value));
40  // TODO: Support multiple blocks.
41  assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
42  "expected exactly 1 block");
43  auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
44  assumingOp.getDoRegion().front().getTerminator());
45  assert(yieldOp && "expected shape.assuming_yield terminator");
46  return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
47  }
48 
49  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50  const BufferizationOptions &options) const {
51  auto assumingOp = cast<shape::AssumingOp>(op);
52  assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
53  "only 1 block supported");
54  auto yieldOp = cast<shape::AssumingYieldOp>(
55  assumingOp.getDoRegion().front().getTerminator());
56 
57  // Create new op and move over region.
58  TypeRange newResultTypes(yieldOp.getOperands());
59  auto newOp = rewriter.create<shape::AssumingOp>(
60  op->getLoc(), newResultTypes, assumingOp.getWitness());
61  newOp.getDoRegion().takeBody(assumingOp.getRegion());
62 
63  // Update all uses of the old op.
64  rewriter.setInsertionPointAfter(newOp);
65  SmallVector<Value> newResults;
66  for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
67  if (isa<TensorType>(it.value())) {
68  newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
69  assumingOp.getLoc(), newOp->getResult(it.index())));
70  } else {
71  newResults.push_back(newOp->getResult(it.index()));
72  }
73  }
74 
75  // Replace old op.
76  rewriter.replaceOp(assumingOp, newResults);
77 
78  return success();
79  }
80 };
81 
82 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
83 /// ops, so this is for analysis only.
84 struct AssumingYieldOpInterface
85  : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86  shape::AssumingYieldOp> {
87  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88  const AnalysisState &state) const {
89  return true;
90  }
91 
92  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93  const AnalysisState &state) const {
94  return false;
95  }
96 
97  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
98  const AnalysisState &state) const {
99  assert(isa<shape::AssumingOp>(op->getParentOp()) &&
100  "expected that parent is an AssumingOp");
101  OpResult opResult =
102  op->getParentOp()->getResult(opOperand.getOperandNumber());
103  return {{opResult, BufferRelation::Equivalent}};
104  }
105 
106  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
107  const AnalysisState &state) const {
108  // Yield operands always bufferize inplace. Otherwise, an alloc + copy
109  // may be generated inside the block. We should not return/yield allocations
110  // when possible.
111  return true;
112  }
113 
114  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
115  const BufferizationOptions &options) const {
116  auto yieldOp = cast<shape::AssumingYieldOp>(op);
117  SmallVector<Value> newResults;
118  for (Value value : yieldOp.getOperands()) {
119  if (isa<TensorType>(value.getType())) {
120  FailureOr<Value> buffer = getBuffer(rewriter, value, options);
121  if (failed(buffer))
122  return failure();
123  newResults.push_back(*buffer);
124  } else {
125  newResults.push_back(value);
126  }
127  }
128  replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
129  newResults);
130  return success();
131  }
132 };
133 
134 } // namespace
135 } // namespace shape
136 } // namespace mlir
137 
139  DialectRegistry &registry) {
140  registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
141  shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
142  shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
143  });
144 }
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:397
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
result_range getOpResults()
Definition: Operation.h:415
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.