MLIR  20.0.0git
UnstructuredControlFlow.h
Go to the documentation of this file.
1 //===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
10 #define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
11 
14 
15 //===----------------------------------------------------------------------===//
16 // Helpers for Unstructured Control Flow
17 //===----------------------------------------------------------------------===//
18 
19 namespace mlir {
20 namespace bufferization {
21 
22 namespace detail {
23 /// Return a list of operands that are forwarded to the given block argument.
24 /// I.e., find all predecessors of the block argument's owner and gather the
25 /// operands that are equivalent to the block argument.
///
/// The external-model templates in this header use the result in two ways:
/// - to infer the buffer type of a block argument from the bufferized types of
///   all values forwarded to it (`getBufferType`);
/// - to report those OpOperands as possible (non-definite) equivalent aliases
///   of the block argument (`getAliasingBranchOpOperands`).
/// NOTE(review): the returned list may be empty. `getBufferType` then reports
/// "could not infer buffer type of block argument". Exactly which predecessor
/// terminators are inspected is decided by the out-of-line definition, which
/// is not visible here.
26 SmallVector<OpOperand *> getCallerOpOperands(BlockArgument bbArg);
27 } // namespace detail
28 
29 /// A template that provides a default implementation of `getAliasingOpOperands`
30 /// for ops that support unstructured control flow within their regions.
31 template <typename ConcreteModel, typename ConcreteOp>
33  : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
34 
35  FailureOr<BaseMemRefType>
37  SmallVector<Value> &invocationStack) const {
38  // Note: The user may want to override this function for OpResults in
39  // case the bufferized result type is different from the bufferized type of
40  // the aliasing OpOperand (if any).
41  if (isa<OpResult>(value))
43  invocationStack);
44 
45  // Compute the buffer type of the block argument by computing the bufferized
46  // operand types of all forwarded values. If these are all the same type,
47  // take that type. Otherwise, take only the memory space and fall back to a
48  // buffer type with a fully dynamic layout map.
49  BaseMemRefType bufferType;
50  auto tensorType = cast<TensorType>(value.getType());
51  for (OpOperand *opOperand :
52  detail::getCallerOpOperands(cast<BlockArgument>(value))) {
53 
54  // If the forwarded operand is already on the invocation stack, we ran
55  // into a loop and this operand cannot be used to compute the bufferized
56  // type.
57  if (llvm::find(invocationStack, opOperand->get()) !=
58  invocationStack.end())
59  continue;
60 
61  // Compute the bufferized type of the forwarded operand.
62  BaseMemRefType callerType;
63  if (auto memrefType =
64  dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
65  // The operand was already bufferized. Take its type directly.
66  callerType = memrefType;
67  } else {
68  FailureOr<BaseMemRefType> maybeCallerType =
69  bufferization::getBufferType(opOperand->get(), options,
70  invocationStack);
71  if (failed(maybeCallerType))
72  return failure();
73  callerType = *maybeCallerType;
74  }
75 
76  if (!bufferType) {
77  // This is the first buffer type that we computed.
78  bufferType = callerType;
79  continue;
80  }
81 
82  if (bufferType == callerType)
83  continue;
84 
85  // If the computed buffer type does not match the computed buffer type
86  // of the earlier forwarded operands, fall back to a buffer type with a
87  // fully dynamic layout map.
88 #ifndef NDEBUG
89  if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
90  assert(bufferType.hasRank() && callerType.hasRank() &&
91  "expected ranked memrefs");
92  assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
93  rankedTensorType.getShape()}) &&
94  "expected same shape");
95  } else {
96  assert(!bufferType.hasRank() && !callerType.hasRank() &&
97  "expected unranked memrefs");
98  }
99 #endif // NDEBUG
100 
101  if (bufferType.getMemorySpace() != callerType.getMemorySpace())
102  return op->emitOpError("incoming operands of block argument have "
103  "inconsistent memory spaces");
104 
106  tensorType, bufferType.getMemorySpace());
107  }
108 
109  if (!bufferType)
110  return op->emitOpError("could not infer buffer type of block argument");
111 
112  return bufferType;
113  }
114 
115 protected:
116  /// Assuming that `bbArg` is a block argument of a block that belongs to the
117  /// given `op`, return all OpOperands of users of this block that are
118  /// aliasing with the given block argument.
121  const AnalysisState &state) const {
122  assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg");
123 
124  // Gather aliasing OpOperands of all operations (callers) that link to
125  // this block.
126  AliasingOpOperandList result;
127  for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg))
128  result.addAlias(
129  {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false});
130 
131  return result;
132  }
133 };
134 
135 /// A template that provides a default implementation of `getAliasingValues`
136 /// for ops that implement the `BranchOpInterface`.
137 template <typename ConcreteModel, typename ConcreteOp>
139  : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
141  const AnalysisState &state) const {
142  AliasingValueList result;
143  auto branchOp = cast<BranchOpInterface>(op);
144  auto operandNumber = opOperand.getOperandNumber();
145 
146  // Gather aliasing block arguments of blocks to which this op may branch to.
147  for (const auto &it : llvm::enumerate(op->getSuccessors())) {
148  Block *block = it.value();
149  SuccessorOperands operands = branchOp.getSuccessorOperands(it.index());
150  assert(operands.getProducedOperandCount() == 0 &&
151  "produced operands not supported");
152  if (operands.getForwardedOperands().empty())
153  continue;
154  // The first and last operands that are forwarded to this successor.
155  int64_t firstOperandIndex =
157  int64_t lastOperandIndex =
158  firstOperandIndex + operands.getForwardedOperands().size();
159  bool matchingDestination = operandNumber >= firstOperandIndex &&
160  operandNumber < lastOperandIndex;
161  // A branch op may have multiple successors. Find the ones that correspond
162  // to this OpOperand. (There is usually only one.)
163  if (!matchingDestination)
164  continue;
165  // Compute the matching block argument of the destination block.
166  BlockArgument bbArg =
167  block->getArgument(operandNumber - firstOperandIndex);
168  result.addAlias(
169  {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false});
170  }
171 
172  return result;
173  }
174 };
175 
176 } // namespace bufferization
177 } // namespace mlir
178 
179 #endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:149
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
SuccessorRange getSuccessors()
Definition: Operation.h:699
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class models how operands are forwarded to block arguments in control flow.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
AnalysisState provides a variety of helper functions for dealing with tensor values.
SmallVector< OpOperand * > getCallerOpOperands(BlockArgument bbArg)
Return a list of operands that are forwarded to the given block argument.
FailureOr< BaseMemRefType > defaultGetBufferType(Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack)
This is the default implementation of BufferizableOpInterface::getBufferType.
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
A template that provides a default implementation of getAliasingValues for ops that implement the BranchOpInterface.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Options for BufferizableOpInterface-based bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstructured control flow within their regions.
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
Assuming that bbArg is a block argument of a block that belongs to the given op, return all OpOperands of users of this block that are aliasing with the given block argument.