MLIR  21.0.0git
UnstructuredControlFlow.h
Go to the documentation of this file.
1 //===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
10 #define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
11 
14 
15 //===----------------------------------------------------------------------===//
16 // Helpers for Unstructured Control Flow
17 //===----------------------------------------------------------------------===//
18 
19 namespace mlir {
20 namespace bufferization {
21 
22 namespace detail {
23 /// Return a list of operands that are forwarded to the given block argument.
24 /// I.e., find all predecessors of the block argument's owner and gather the
25 /// operands that are equivalent to the block argument.
26 SmallVector<OpOperand *> getCallerOpOperands(BlockArgument bbArg);
27 } // namespace detail
28 
29 /// A template that provides a default implementation of `getAliasingOpOperands`
30 /// for ops that support unstructured control flow within their regions.
31 template <typename ConcreteModel, typename ConcreteOp>
33  : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
34 
35  FailureOr<BufferLikeType>
37  const BufferizationState &state,
38  SmallVector<Value> &invocationStack) const {
39  // Note: The user may want to override this function for OpResults in
40  // case the bufferized result type is different from the bufferized type of
41  // the aliasing OpOperand (if any).
42  if (isa<OpResult>(value))
44  invocationStack);
45 
46  // Compute the buffer type of the block argument by computing the bufferized
47  // operand types of all forwarded values. If these are all the same type,
48  // take that type. Otherwise, take only the memory space and fall back to a
49  // buffer type with a fully dynamic layout map.
50  BaseMemRefType bufferType;
51  auto tensorType = cast<TensorType>(value.getType());
52  for (OpOperand *opOperand :
53  detail::getCallerOpOperands(cast<BlockArgument>(value))) {
54 
55  // If the forwarded operand is already on the invocation stack, we ran
56  // into a loop and this operand cannot be used to compute the bufferized
57  // type.
58  if (llvm::is_contained(invocationStack, opOperand->get()))
59  continue;
60 
61  // Compute the bufferized type of the forwarded operand.
62  BaseMemRefType callerType;
63  if (auto memrefType =
64  dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
65  // The operand was already bufferized. Take its type directly.
66  callerType = memrefType;
67  } else {
68  FailureOr<BufferLikeType> maybeCallerType =
69  bufferization::getBufferType(opOperand->get(), options, state,
70  invocationStack);
71  if (failed(maybeCallerType))
72  return failure();
73  assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
74  callerType = cast<BaseMemRefType>(*maybeCallerType);
75  }
76 
77  if (!bufferType) {
78  // This is the first buffer type that we computed.
79  bufferType = callerType;
80  continue;
81  }
82 
83  if (bufferType == callerType)
84  continue;
85 
86  // If the computed buffer type does not match the computed buffer type
87  // of the earlier forwarded operands, fall back to a buffer type with a
88  // fully dynamic layout map.
89 #ifndef NDEBUG
90  if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
91  assert(bufferType.hasRank() && callerType.hasRank() &&
92  "expected ranked memrefs");
93  assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
94  rankedTensorType.getShape()}) &&
95  "expected same shape");
96  } else {
97  assert(!bufferType.hasRank() && !callerType.hasRank() &&
98  "expected unranked memrefs");
99  }
100 #endif // NDEBUG
101 
102  if (bufferType.getMemorySpace() != callerType.getMemorySpace())
103  return op->emitOpError("incoming operands of block argument have "
104  "inconsistent memory spaces");
105 
107  tensorType, bufferType.getMemorySpace());
108  }
109 
110  if (!bufferType)
111  return op->emitOpError("could not infer buffer type of block argument");
112 
113  return cast<BufferLikeType>(bufferType);
114  }
115 
116 protected:
117  /// Assuming that `bbArg` is a block argument of a block that belongs to the
118  /// given `op`, return all OpOperands of users of this block that are
119  /// aliasing with the given block argument.
122  const AnalysisState &state) const {
123  assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg");
124 
125  // Gather aliasing OpOperands of all operations (callers) that link to
126  // this block.
127  AliasingOpOperandList result;
128  for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg))
129  result.addAlias(
130  {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false});
131 
132  return result;
133  }
134 };
135 
136 /// A template that provides a default implementation of `getAliasingValues`
137 /// for ops that implement the `BranchOpInterface`.
138 template <typename ConcreteModel, typename ConcreteOp>
140  : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
142  const AnalysisState &state) const {
143  AliasingValueList result;
144  auto branchOp = cast<BranchOpInterface>(op);
145  auto operandNumber = opOperand.getOperandNumber();
146 
147  // Gather aliasing block arguments of blocks to which this op may branch to.
148  for (const auto &it : llvm::enumerate(op->getSuccessors())) {
149  Block *block = it.value();
150  SuccessorOperands operands = branchOp.getSuccessorOperands(it.index());
151  assert(operands.getProducedOperandCount() == 0 &&
152  "produced operands not supported");
153  if (operands.getForwardedOperands().empty())
154  continue;
155  // The first and last operands that are forwarded to this successor.
156  int64_t firstOperandIndex =
158  int64_t lastOperandIndex =
159  firstOperandIndex + operands.getForwardedOperands().size();
160  bool matchingDestination = operandNumber >= firstOperandIndex &&
161  operandNumber < lastOperandIndex;
162  // A branch op may have multiple successors. Find the ones that correspond
163  // to this OpOperand. (There is usually only one.)
164  if (!matchingDestination)
165  continue;
166  // Compute the matching block argument of the destination block.
167  BlockArgument bbArg =
168  block->getArgument(operandNumber - firstOperandIndex);
169  result.addAlias(
170  {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false});
171  }
172 
173  return result;
174  }
175 };
176 
177 } // namespace bufferization
178 } // namespace mlir
179 
180 #endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
SuccessorRange getSuccessors()
Definition: Operation.h:703
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class models how operands are forwarded to block arguments in control flow.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
AnalysisState provides a variety of helper functions for dealing with tensor values.
BufferizationState provides information about the state of the IR during the bufferization process.
FailureOr< BufferLikeType > defaultGetBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack)
This is the default implementation of BufferizableOpInterface::getBufferType.
SmallVector< OpOperand * > getCallerOpOperands(BlockArgument bbArg)
Return a list of operands that are forwarded to the given block argument.
FailureOr< BufferLikeType > getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
A template that provides a default implementation of getAliasingValues for ops that implement the BranchOpInterface.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Options for BufferizableOpInterface-based bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstructured control flow within their regions.
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
Assuming that bbArg is a block argument of a block that belongs to the given op, return all OpOperands of users of this block that are aliasing with the given block argument.
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const