MLIR  21.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/BuiltinTypes.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // ControlFlowInterfaces
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
23 
25  : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
26 }
27 
28 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
29  MutableOperandRange forwardedOperands)
30  : producedOperandCount(producedOperandCount),
31  forwardedOperands(std::move(forwardedOperands)) {}
32 
33 //===----------------------------------------------------------------------===//
34 // BranchOpInterface
35 //===----------------------------------------------------------------------===//
36 
37 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
38 /// successor if 'operandIndex' is within the range of 'operands', or
39 /// std::nullopt if `operandIndex` isn't a successor operand index.
40 std::optional<BlockArgument>
42  unsigned operandIndex, Block *successor) {
43  OperandRange forwardedOperands = operands.getForwardedOperands();
44  // Check that the operands are valid.
45  if (forwardedOperands.empty())
46  return std::nullopt;
47 
48  // Check to ensure that this operand is within the range.
49  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
50  if (operandIndex < operandsStart ||
51  operandIndex >= (operandsStart + forwardedOperands.size()))
52  return std::nullopt;
53 
54  // Index the successor.
55  unsigned argIndex =
56  operands.getProducedOperandCount() + operandIndex - operandsStart;
57  return successor->getArgument(argIndex);
58 }
59 
60 /// Verify that the given operands match those of the given successor block.
61 LogicalResult
63  const SuccessorOperands &operands) {
64  // Check the count.
65  unsigned operandCount = operands.size();
66  Block *destBB = op->getSuccessor(succNo);
67  if (operandCount != destBB->getNumArguments())
68  return op->emitError() << "branch has " << operandCount
69  << " operands for successor #" << succNo
70  << ", but target block has "
71  << destBB->getNumArguments();
72 
73  // Check the types.
74  for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
75  ++i) {
76  if (!cast<BranchOpInterface>(op).areTypesCompatible(
77  operands[i].getType(), destBB->getArgument(i).getType()))
78  return op->emitError() << "type mismatch for bb argument #" << i
79  << " of successor #" << succNo;
80  }
81  return success();
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // WeightedBranchOpInterface
86 //===----------------------------------------------------------------------===//
87 
88 static LogicalResult verifyWeights(Operation *op,
90  std::size_t expectedWeightsNum,
91  llvm::StringRef weightAnchorName,
92  llvm::StringRef weightRefName) {
93  if (weights.empty())
94  return success();
95 
96  if (weights.size() != expectedWeightsNum)
97  return op->emitError() << "expects number of " << weightAnchorName
98  << " weights to match number of " << weightRefName
99  << ": " << weights.size() << " vs "
100  << expectedWeightsNum;
101 
102  for (auto [index, weight] : llvm::enumerate(weights))
103  if (weight < 0)
104  return op->emitError() << "weight #" << index << " must be non-negative";
105 
106  if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
107  return op->emitError() << "branch weights cannot all be zero";
108 
109  return success();
110 }
111 
113  llvm::ArrayRef<int32_t> weights =
114  cast<WeightedBranchOpInterface>(op).getWeights();
115  return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
116  "successors");
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // WeightedRegionBranchOpInterface
121 //===----------------------------------------------------------------------===//
122 
124  llvm::ArrayRef<int32_t> weights =
125  cast<WeightedRegionBranchOpInterface>(op).getWeights();
126  return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // RegionBranchOpInterface
131 //===----------------------------------------------------------------------===//
132 
134  RegionBranchPoint sourceNo,
135  RegionBranchPoint succRegionNo) {
136  diag << "from ";
137  if (Region *region = sourceNo.getRegionOrNull())
138  diag << "Region #" << region->getRegionNumber();
139  else
140  diag << "parent operands";
141 
142  diag << " to ";
143  if (Region *region = succRegionNo.getRegionOrNull())
144  diag << "Region #" << region->getRegionNumber();
145  else
146  diag << "parent results";
147  return diag;
148 }
149 
150 /// Verify that types match along all region control flow edges originating from
151 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
152 /// types of the inputs that flow to a successor region.
153 static LogicalResult
155  function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
156  getInputsTypesForRegion) {
157  auto regionInterface = cast<RegionBranchOpInterface>(op);
158 
160  regionInterface.getSuccessorRegions(sourcePoint, successors);
161 
162  for (RegionSuccessor &succ : successors) {
163  FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
164  if (failed(sourceTypes))
165  return failure();
166 
167  TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
168  if (sourceTypes->size() != succInputsTypes.size()) {
169  InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
170  return printRegionEdgeName(diag, sourcePoint, succ)
171  << ": source has " << sourceTypes->size()
172  << " operands, but target successor needs "
173  << succInputsTypes.size();
174  }
175 
176  for (const auto &typesIdx :
177  llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
178  Type sourceType = std::get<0>(typesIdx.value());
179  Type inputType = std::get<1>(typesIdx.value());
180  if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
181  InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
182  return printRegionEdgeName(diag, sourcePoint, succ)
183  << ": source type #" << typesIdx.index() << " " << sourceType
184  << " should match input type #" << typesIdx.index() << " "
185  << inputType;
186  }
187  }
188  }
189  return success();
190 }
191 
192 /// Verify that types match along control flow edges described by the given op.
194  auto regionInterface = cast<RegionBranchOpInterface>(op);
195 
196  auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
197  return regionInterface.getEntrySuccessorOperands(point).getTypes();
198  };
199 
200  // Verify types along control flow edges originating from the parent.
202  inputTypesFromParent)))
203  return failure();
204 
205  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
206  if (lhs.size() != rhs.size())
207  return false;
208  for (auto types : llvm::zip(lhs, rhs)) {
209  if (!regionInterface.areTypesCompatible(std::get<0>(types),
210  std::get<1>(types))) {
211  return false;
212  }
213  }
214  return true;
215  };
216 
217  // Verify types along control flow edges originating from each region.
218  for (Region &region : op->getRegions()) {
219 
220  // Since there can be multiple terminators implementing the
221  // `RegionBranchTerminatorOpInterface`, all should have the same operand
222  // types when passing them to the same region.
223 
225  for (Block &block : region)
226  if (!block.empty())
227  if (auto terminator =
228  dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
229  regionReturnOps.push_back(terminator);
230 
231  // If there is no return-like terminator, the op itself should verify
232  // type consistency.
233  if (regionReturnOps.empty())
234  continue;
235 
236  auto inputTypesForRegion =
237  [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
238  std::optional<OperandRange> regionReturnOperands;
239  for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
240  auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
241 
242  if (!regionReturnOperands) {
243  regionReturnOperands = terminatorOperands;
244  continue;
245  }
246 
247  // Found more than one ReturnLike terminator. Make sure the operand
248  // types match with the first one.
249  if (!areTypesCompatible(regionReturnOperands->getTypes(),
250  terminatorOperands.getTypes())) {
251  InFlightDiagnostic diag = op->emitOpError("along control flow edge");
252  return printRegionEdgeName(diag, region, point)
253  << " operands mismatch between return-like terminators";
254  }
255  }
256 
257  // All successors get the same set of operand types.
258  return TypeRange(regionReturnOperands->getTypes());
259  };
260 
261  if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
262  return failure();
263  }
264 
265  return success();
266 }
267 
268 /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
269 /// this function returns "true" for a successor region. The first parameter is
270 /// the successor region. The second parameter indicates all already visited
271 /// regions.
273 
274 /// Traverse the region graph starting at `begin`. The traversal is interrupted
275 /// if `stopCondition` evaluates to "true" for a successor region. In that case,
276 /// this function returns "true". Otherwise, if the traversal was not
277 /// interrupted, this function returns "false".
278 static bool traverseRegionGraph(Region *begin,
279  StopConditionFn stopConditionFn) {
280  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
281  SmallVector<bool> visited(op->getNumRegions(), false);
282  visited[begin->getRegionNumber()] = true;
283 
284  // Retrieve all successors of the region and enqueue them in the worklist.
285  SmallVector<Region *> worklist;
286  auto enqueueAllSuccessors = [&](Region *region) {
287  SmallVector<RegionSuccessor> successors;
288  op.getSuccessorRegions(region, successors);
289  for (RegionSuccessor successor : successors)
290  if (!successor.isParent())
291  worklist.push_back(successor.getSuccessor());
292  };
293  enqueueAllSuccessors(begin);
294 
295  // Process all regions in the worklist via DFS.
296  while (!worklist.empty()) {
297  Region *nextRegion = worklist.pop_back_val();
298  if (stopConditionFn(nextRegion, visited))
299  return true;
300  if (visited[nextRegion->getRegionNumber()])
301  continue;
302  visited[nextRegion->getRegionNumber()] = true;
303  enqueueAllSuccessors(nextRegion);
304  }
305 
306  return false;
307 }
308 
309 /// Return `true` if region `r` is reachable from region `begin` according to
310 /// the RegionBranchOpInterface (by taking a branch).
311 static bool isRegionReachable(Region *begin, Region *r) {
312  assert(begin->getParentOp() == r->getParentOp() &&
313  "expected that both regions belong to the same op");
314  return traverseRegionGraph(begin,
315  [&](Region *nextRegion, ArrayRef<bool> visited) {
316  // Interrupt traversal if `r` was reached.
317  return nextRegion == r;
318  });
319 }
320 
321 /// Return `true` if `a` and `b` are in mutually exclusive regions.
322 ///
323 /// 1. Find the first common ancestor of `a` and `b` that implements
324 /// RegionBranchOpInterface.
325 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
326 /// contained.
327 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
328 /// mutually exclusive if they are not reachable from each other as per
329 /// RegionBranchOpInterface::getSuccessorRegions.
331  assert(a && "expected non-empty operation");
332  assert(b && "expected non-empty operation");
333 
334  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
335  while (branchOp) {
336  // Check if b is inside branchOp. (We already know that a is.)
337  if (!branchOp->isProperAncestor(b)) {
338  // Check next enclosing RegionBranchOpInterface.
339  branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
340  continue;
341  }
342 
343  // b is contained in branchOp. Retrieve the regions in which `a` and `b`
344  // are contained.
345  Region *regionA = nullptr, *regionB = nullptr;
346  for (Region &r : branchOp->getRegions()) {
347  if (r.findAncestorOpInRegion(*a)) {
348  assert(!regionA && "already found a region for a");
349  regionA = &r;
350  }
351  if (r.findAncestorOpInRegion(*b)) {
352  assert(!regionB && "already found a region for b");
353  regionB = &r;
354  }
355  }
356  assert(regionA && regionB && "could not find region of op");
357 
358  // `a` and `b` are in mutually exclusive regions if both regions are
359  // distinct and neither region is reachable from the other region.
360  return regionA != regionB && !isRegionReachable(regionA, regionB) &&
361  !isRegionReachable(regionB, regionA);
362  }
363 
364  // Could not find a common RegionBranchOpInterface among a's and b's
365  // ancestors.
366  return false;
367 }
368 
369 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
370  Region *region = &getOperation()->getRegion(index);
371  return isRegionReachable(region, region);
372 }
373 
374 bool RegionBranchOpInterface::hasLoop() {
375  SmallVector<RegionSuccessor> entryRegions;
376  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
377  for (RegionSuccessor successor : entryRegions)
378  if (!successor.isParent() &&
379  traverseRegionGraph(successor.getSuccessor(),
380  [](Region *nextRegion, ArrayRef<bool> visited) {
381  // Interrupt traversal if the region was already
382  // visited.
383  return visited[nextRegion->getRegionNumber()];
384  }))
385  return true;
386  return false;
387 }
388 
390  while (Region *region = op->getParentRegion()) {
391  op = region->getParentOp();
392  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
393  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
394  return region;
395  }
396  return nullptr;
397 }
398 
400  Region *region = value.getParentRegion();
401  while (region) {
402  Operation *op = region->getParentOp();
403  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
404  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
405  return region;
406  region = op->getParentRegion();
407  }
408  return nullptr;
409 }
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionBranchPoint succRegionNo)
static LogicalResult verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionBranchPoint)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourcePoint.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getSuccessor(unsigned index)
Definition: Operation.h:708
unsigned getNumSuccessors()
Definition: Operation.h:706
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described by the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...