MLIR 22.0.0git
UnstructuredControlFlow.h
Go to the documentation of this file.
1//===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
10#define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
11
14
15//===----------------------------------------------------------------------===//
16// Helpers for Unstructured Control Flow
17//===----------------------------------------------------------------------===//
18
19namespace mlir {
20namespace bufferization {
21
22namespace detail {
23/// Return a list of operands that are forwarded to the given block argument.
24/// I.e., find all predecessors of the block argument's owner and gather the
25/// operands that are equivalent to the block argument.
27} // namespace detail
28
29/// A template that provides a default implementation of `getAliasingOpOperands`
30/// for ops that support unstructured control flow within their regions.
31template <typename ConcreteModel, typename ConcreteOp>
33 : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
34
35 FailureOr<BufferLikeType>
37 const BufferizationState &state,
38 SmallVector<Value> &invocationStack) const {
39 // Note: The user may want to override this function for OpResults in
40 // case the bufferized result type is different from the bufferized type of
41 // the aliasing OpOperand (if any).
42 if (isa<OpResult>(value))
43 return bufferization::detail::defaultGetBufferType(value, options, state,
44 invocationStack);
45
46 // Compute the buffer type of the block argument by computing the bufferized
47 // operand types of all forwarded values. If these are all the same type,
48 // take that type. Otherwise, take only the memory space and fall back to a
49 // buffer type with a fully dynamic layout map.
50 BaseMemRefType bufferType;
51 auto tensorType = cast<TensorType>(value.getType());
52 for (OpOperand *opOperand :
53 detail::getCallerOpOperands(cast<BlockArgument>(value))) {
54
55 // If the forwarded operand is already on the invocation stack, we ran
56 // into a loop and this operand cannot be used to compute the bufferized
57 // type.
58 if (llvm::is_contained(invocationStack, opOperand->get()))
59 continue;
60
61 // Compute the bufferized type of the forwarded operand.
62 BaseMemRefType callerType;
63 if (auto memrefType =
64 dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
65 // The operand was already bufferized. Take its type directly.
66 callerType = memrefType;
67 } else {
68 FailureOr<BufferLikeType> maybeCallerType =
69 bufferization::getBufferType(opOperand->get(), options, state,
70 invocationStack);
71 if (failed(maybeCallerType))
72 return failure();
73 assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
74 callerType = cast<BaseMemRefType>(*maybeCallerType);
75 }
76
77 if (!bufferType) {
78 // This is the first buffer type that we computed.
79 bufferType = callerType;
80 continue;
81 }
82
83 if (bufferType == callerType)
84 continue;
85
86 // If the computed buffer type does not match the computed buffer type
87 // of the earlier forwarded operands, fall back to a buffer type with a
88 // fully dynamic layout map.
89#ifndef NDEBUG
90 if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
91 assert(bufferType.hasRank() && callerType.hasRank() &&
92 "expected ranked memrefs");
93 assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
94 rankedTensorType.getShape()}) &&
95 "expected same shape");
96 } else {
97 assert(!bufferType.hasRank() && !callerType.hasRank() &&
98 "expected unranked memrefs");
99 }
100#endif // NDEBUG
101
102 if (bufferType.getMemorySpace() != callerType.getMemorySpace())
103 return op->emitOpError("incoming operands of block argument have "
104 "inconsistent memory spaces");
105
106 bufferType = getMemRefTypeWithFullyDynamicLayout(
107 tensorType, bufferType.getMemorySpace());
108 }
109
110 if (!bufferType)
111 return op->emitOpError("could not infer buffer type of block argument");
112
113 return cast<BufferLikeType>(bufferType);
114 }
115
116protected:
117 /// Assuming that `bbArg` is a block argument of a block that belongs to the
118 /// given `op`, return all OpOperands of users of this block that are
119 /// aliasing with the given block argument.
120 AliasingOpOperandList
122 const AnalysisState &state) const {
123 assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg");
124
125 // Gather aliasing OpOperands of all operations (callers) that link to
126 // this block.
127 AliasingOpOperandList result;
128 for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg))
129 result.addAlias(
130 {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false});
131
132 return result;
133 }
134};
135
136/// A template that provides a default implementation of `getAliasingValues`
137/// for ops that implement the `BranchOpInterface`.
138template <typename ConcreteModel, typename ConcreteOp>
140 : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
141 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
142 const AnalysisState &state) const {
143 AliasingValueList result;
144 auto branchOp = cast<BranchOpInterface>(op);
145 auto operandNumber = opOperand.getOperandNumber();
146
147 // Gather aliasing block arguments of blocks to which this op may branch to.
148 for (const auto &it : llvm::enumerate(op->getSuccessors())) {
149 Block *block = it.value();
150 SuccessorOperands operands = branchOp.getSuccessorOperands(it.index());
151 assert(operands.getProducedOperandCount() == 0 &&
152 "produced operands not supported");
153 if (operands.getForwardedOperands().empty())
154 continue;
155 // The first and last operands that are forwarded to this successor.
156 int64_t firstOperandIndex =
158 int64_t lastOperandIndex =
159 firstOperandIndex + operands.getForwardedOperands().size();
160 bool matchingDestination = operandNumber >= firstOperandIndex &&
161 operandNumber < lastOperandIndex;
162 // A branch op may have multiple successors. Find the ones that correspond
163 // to this OpOperand. (There is usually only one.)
164 if (!matchingDestination)
165 continue;
166 // Compute the matching block argument of the destination block.
167 BlockArgument bbArg =
168 block->getArgument(operandNumber - firstOperandIndex);
169 result.addAlias(
170 {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false});
171 }
172
173 return result;
174 }
175};
176
177} // namespace bufferization
178} // namespace mlir
179
180#endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class represents an argument of a Block.
Definition Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
SuccessorRange getSuccessors()
Definition Operation.h:703
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class models how operands are forwarded to block arguments in control flow.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
SmallVector< OpOperand * > getCallerOpOperands(BlockArgument bbArg)
Return a list of operands that are forwarded to the given block argument.
Include the generated interface declarations.
A template that provides a default implementation of getAliasingValues for ops that implement the BranchOpInterface.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
A template that provides a default implementation of getAliasingOpOperands for ops that support unstructured control flow within their regions.
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
Assuming that bbArg is a block argument of a block that belongs to the given op, return all OpOperands of users of this block that are aliasing with the given block argument.
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const