MLIR  20.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
22 
24  : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25 }
26 
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28  MutableOperandRange forwardedOperands)
29  : producedOperandCount(producedOperandCount),
30  forwardedOperands(std::move(forwardedOperands)) {}
31 
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional<BlockArgument>
41  unsigned operandIndex, Block *successor) {
42  OperandRange forwardedOperands = operands.getForwardedOperands();
43  // Check that the operands are valid.
44  if (forwardedOperands.empty())
45  return std::nullopt;
46 
47  // Check to ensure that this operand is within the range.
48  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49  if (operandIndex < operandsStart ||
50  operandIndex >= (operandsStart + forwardedOperands.size()))
51  return std::nullopt;
52 
53  // Index the successor.
54  unsigned argIndex =
55  operands.getProducedOperandCount() + operandIndex - operandsStart;
56  return successor->getArgument(argIndex);
57 }
58 
59 /// Verify that the given operands match those of the given successor block.
60 LogicalResult
62  const SuccessorOperands &operands) {
63  // Check the count.
64  unsigned operandCount = operands.size();
65  Block *destBB = op->getSuccessor(succNo);
66  if (operandCount != destBB->getNumArguments())
67  return op->emitError() << "branch has " << operandCount
68  << " operands for successor #" << succNo
69  << ", but target block has "
70  << destBB->getNumArguments();
71 
72  // Check the types.
73  for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74  ++i) {
75  if (!cast<BranchOpInterface>(op).areTypesCompatible(
76  operands[i].getType(), destBB->getArgument(i).getType()))
77  return op->emitError() << "type mismatch for bb argument #" << i
78  << " of successor #" << succNo;
79  }
80  return success();
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
86 
88  RegionBranchPoint sourceNo,
89  RegionBranchPoint succRegionNo) {
90  diag << "from ";
91  if (Region *region = sourceNo.getRegionOrNull())
92  diag << "Region #" << region->getRegionNumber();
93  else
94  diag << "parent operands";
95 
96  diag << " to ";
97  if (Region *region = succRegionNo.getRegionOrNull())
98  diag << "Region #" << region->getRegionNumber();
99  else
100  diag << "parent results";
101  return diag;
102 }
103 
104 /// Verify that types match along all region control flow edges originating from
105 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106 /// types of the inputs that flow to a successor region.
107 static LogicalResult
109  function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
110  getInputsTypesForRegion) {
111  auto regionInterface = cast<RegionBranchOpInterface>(op);
112 
114  regionInterface.getSuccessorRegions(sourcePoint, successors);
115 
116  for (RegionSuccessor &succ : successors) {
117  FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
118  if (failed(sourceTypes))
119  return failure();
120 
121  TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
122  if (sourceTypes->size() != succInputsTypes.size()) {
123  InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
124  return printRegionEdgeName(diag, sourcePoint, succ)
125  << ": source has " << sourceTypes->size()
126  << " operands, but target successor needs "
127  << succInputsTypes.size();
128  }
129 
130  for (const auto &typesIdx :
131  llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
132  Type sourceType = std::get<0>(typesIdx.value());
133  Type inputType = std::get<1>(typesIdx.value());
134  if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
135  InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
136  return printRegionEdgeName(diag, sourcePoint, succ)
137  << ": source type #" << typesIdx.index() << " " << sourceType
138  << " should match input type #" << typesIdx.index() << " "
139  << inputType;
140  }
141  }
142  }
143  return success();
144 }
145 
146 /// Verify that types match along control flow edges described the given op.
148  auto regionInterface = cast<RegionBranchOpInterface>(op);
149 
150  auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
151  return regionInterface.getEntrySuccessorOperands(point).getTypes();
152  };
153 
154  // Verify types along control flow edges originating from the parent.
156  inputTypesFromParent)))
157  return failure();
158 
159  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
160  if (lhs.size() != rhs.size())
161  return false;
162  for (auto types : llvm::zip(lhs, rhs)) {
163  if (!regionInterface.areTypesCompatible(std::get<0>(types),
164  std::get<1>(types))) {
165  return false;
166  }
167  }
168  return true;
169  };
170 
171  // Verify types along control flow edges originating from each region.
172  for (Region &region : op->getRegions()) {
173 
174  // Since there can be multiple terminators implementing the
175  // `RegionBranchTerminatorOpInterface`, all should have the same operand
176  // types when passing them to the same region.
177 
179  for (Block &block : region)
180  if (!block.empty())
181  if (auto terminator =
182  dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
183  regionReturnOps.push_back(terminator);
184 
185  // If there is no return-like terminator, the op itself should verify
186  // type consistency.
187  if (regionReturnOps.empty())
188  continue;
189 
190  auto inputTypesForRegion =
191  [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
192  std::optional<OperandRange> regionReturnOperands;
193  for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
194  auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
195 
196  if (!regionReturnOperands) {
197  regionReturnOperands = terminatorOperands;
198  continue;
199  }
200 
201  // Found more than one ReturnLike terminator. Make sure the operand
202  // types match with the first one.
203  if (!areTypesCompatible(regionReturnOperands->getTypes(),
204  terminatorOperands.getTypes())) {
205  InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
206  return printRegionEdgeName(diag, region, point)
207  << " operands mismatch between return-like terminators";
208  }
209  }
210 
211  // All successors get the same set of operand types.
212  return TypeRange(regionReturnOperands->getTypes());
213  };
214 
215  if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
216  return failure();
217  }
218 
219  return success();
220 }
221 
222 /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
223 /// this function returns "true" for a successor region. The first parameter is
224 /// the successor region. The second parameter indicates all already visited
225 /// regions.
227 
228 /// Traverse the region graph starting at `begin`. The traversal is interrupted
229 /// if `stopCondition` evaluates to "true" for a successor region. In that case,
230 /// this function returns "true". Otherwise, if the traversal was not
231 /// interrupted, this function returns "false".
232 static bool traverseRegionGraph(Region *begin,
233  StopConditionFn stopConditionFn) {
234  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
235  SmallVector<bool> visited(op->getNumRegions(), false);
236  visited[begin->getRegionNumber()] = true;
237 
238  // Retrieve all successors of the region and enqueue them in the worklist.
239  SmallVector<Region *> worklist;
240  auto enqueueAllSuccessors = [&](Region *region) {
241  SmallVector<RegionSuccessor> successors;
242  op.getSuccessorRegions(region, successors);
243  for (RegionSuccessor successor : successors)
244  if (!successor.isParent())
245  worklist.push_back(successor.getSuccessor());
246  };
247  enqueueAllSuccessors(begin);
248 
249  // Process all regions in the worklist via DFS.
250  while (!worklist.empty()) {
251  Region *nextRegion = worklist.pop_back_val();
252  if (stopConditionFn(nextRegion, visited))
253  return true;
254  if (visited[nextRegion->getRegionNumber()])
255  continue;
256  visited[nextRegion->getRegionNumber()] = true;
257  enqueueAllSuccessors(nextRegion);
258  }
259 
260  return false;
261 }
262 
263 /// Return `true` if region `r` is reachable from region `begin` according to
264 /// the RegionBranchOpInterface (by taking a branch).
265 static bool isRegionReachable(Region *begin, Region *r) {
266  assert(begin->getParentOp() == r->getParentOp() &&
267  "expected that both regions belong to the same op");
268  return traverseRegionGraph(begin,
269  [&](Region *nextRegion, ArrayRef<bool> visited) {
270  // Interrupt traversal if `r` was reached.
271  return nextRegion == r;
272  });
273 }
274 
275 /// Return `true` if `a` and `b` are in mutually exclusive regions.
276 ///
277 /// 1. Find the first common of `a` and `b` (ancestor) that implements
278 /// RegionBranchOpInterface.
279 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
280 /// contained.
281 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
282 /// mutually exclusive if they are not reachable from each other as per
283 /// RegionBranchOpInterface::getSuccessorRegions.
285  assert(a && "expected non-empty operation");
286  assert(b && "expected non-empty operation");
287 
288  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
289  while (branchOp) {
290  // Check if b is inside branchOp. (We already know that a is.)
291  if (!branchOp->isProperAncestor(b)) {
292  // Check next enclosing RegionBranchOpInterface.
293  branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
294  continue;
295  }
296 
297  // b is contained in branchOp. Retrieve the regions in which `a` and `b`
298  // are contained.
299  Region *regionA = nullptr, *regionB = nullptr;
300  for (Region &r : branchOp->getRegions()) {
301  if (r.findAncestorOpInRegion(*a)) {
302  assert(!regionA && "already found a region for a");
303  regionA = &r;
304  }
305  if (r.findAncestorOpInRegion(*b)) {
306  assert(!regionB && "already found a region for b");
307  regionB = &r;
308  }
309  }
310  assert(regionA && regionB && "could not find region of op");
311 
312  // `a` and `b` are in mutually exclusive regions if both regions are
313  // distinct and neither region is reachable from the other region.
314  return regionA != regionB && !isRegionReachable(regionA, regionB) &&
315  !isRegionReachable(regionB, regionA);
316  }
317 
318  // Could not find a common RegionBranchOpInterface among a's and b's
319  // ancestors.
320  return false;
321 }
322 
323 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
324  Region *region = &getOperation()->getRegion(index);
325  return isRegionReachable(region, region);
326 }
327 
328 bool RegionBranchOpInterface::hasLoop() {
329  SmallVector<RegionSuccessor> entryRegions;
330  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
331  for (RegionSuccessor successor : entryRegions)
332  if (!successor.isParent() &&
333  traverseRegionGraph(successor.getSuccessor(),
334  [](Region *nextRegion, ArrayRef<bool> visited) {
335  // Interrupt traversal if the region was already
336  // visited.
337  return visited[nextRegion->getRegionNumber()];
338  }))
339  return true;
340  return false;
341 }
342 
344  while (Region *region = op->getParentRegion()) {
345  op = region->getParentOp();
346  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
347  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
348  return region;
349  }
350  return nullptr;
351 }
352 
354  Region *region = value.getParentRegion();
355  while (region) {
356  Operation *op = region->getParentOp();
357  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
358  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
359  return region;
360  region = op->getParentRegion();
361  }
362  return nullptr;
363 }
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionBranchPoint succRegionNo)
static LogicalResult verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionBranchPoint)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourcePoint.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by taking a branch).
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is within the range of operands, or std::nullopt if operandIndex isn't a successor operand index.
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described by the given op.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...