MLIR  22.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/BuiltinTypes.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // ControlFlowInterfaces
18 //===----------------------------------------------------------------------===//
19 
20 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
21 
// Constructs a SuccessorOperands with no produced operands
// (producedOperandCount is 0): every successor operand comes from
// `forwardedOperands`.
// NOTE(review): the constructor's signature (listing line 22) is missing from
// this copy; only its member-initializer list is shown below.
23  : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
24 }
25 
26 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
27  MutableOperandRange forwardedOperands)
28  : producedOperandCount(producedOperandCount),
29  forwardedOperands(std::move(forwardedOperands)) {}
30 
31 //===----------------------------------------------------------------------===//
32 // BranchOpInterface
33 //===----------------------------------------------------------------------===//
34 
35 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
36 /// successor if `operandIndex` is within the range of `operands`, or
37 /// std::nullopt if `operandIndex` isn't a successor operand index.
/// `operandIndex` indexes the branching op's full operand list. Only the
/// forwarded operands of `operands` can map to a block argument. Produced
/// operands occupy the leading successor arguments, so forwarded operands map
/// to arguments starting at `operands.getProducedOperandCount()`.
// NOTE(review): listing line 39 (start of the signature) is missing from this
// copy.
38 std::optional<BlockArgument>
40  unsigned operandIndex, Block *successor) {
41  OperandRange forwardedOperands = operands.getForwardedOperands();
42  // Check that the operands are valid.
// With nothing forwarded, no operand index can map to a block argument.
43  if (forwardedOperands.empty())
44  return std::nullopt;
45 
46  // Check to ensure that this operand is within the range.
// The forwarded operands form one contiguous slice of the op's operands.
47  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
48  if (operandIndex < operandsStart ||
49  operandIndex >= (operandsStart + forwardedOperands.size()))
50  return std::nullopt;
51 
52  // Index the successor.
// Skip past the produced-operand arguments, then add the offset within the
// forwarded slice.
53  unsigned argIndex =
54  operands.getProducedOperandCount() + operandIndex - operandsStart;
55  return successor->getArgument(argIndex);
56 }
57 
58 /// Verify that the given operands match those of the given successor block.
/// Checks two things against successor #`succNo` of `op`:
///  1. `operands.size()` must equal the block's argument count.
///  2. Each operand from index `getProducedOperandCount()` onward must have a
///     type accepted by BranchOpInterface::areTypesCompatible against the
///     matching block argument.
/// On a mismatch an error is emitted on `op` and failure is returned.
// NOTE(review): listing line 60 (start of the signature) is missing from this
// copy.
59 LogicalResult
61  const SuccessorOperands &operands) {
62  // Check the count.
63  unsigned operandCount = operands.size();
64  Block *destBB = op->getSuccessor(succNo);
65  if (operandCount != destBB->getNumArguments())
66  return op->emitError() << "branch has " << operandCount
67  << " operands for successor #" << succNo
68  << ", but target block has "
69  << destBB->getNumArguments();
70 
71  // Check the types.
// Produced operands are skipped. Presumably they have no SSA value on the op
// whose type could be compared (confirm against SuccessorOperands::operator[]).
72  for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
73  ++i) {
74  if (!cast<BranchOpInterface>(op).areTypesCompatible(
75  operands[i].getType(), destBB->getArgument(i).getType()))
76  return op->emitError() << "type mismatch for bb argument #" << i
77  << " of successor #" << succNo;
78  }
79  return success();
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // WeightedBranchOpInterface
84 //===----------------------------------------------------------------------===//
85 
// Shared checker for optional weight lists attached to `op`.
// An empty `weights` list is always accepted. Otherwise its size must equal
// `expectedWeightsNum`, and the entries must not all be zero.
// `weightAnchorName` and `weightRefName` are only used to word the
// size-mismatch diagnostic.
// NOTE(review): the `weights` parameter line (listing line 87) is missing from
// this copy.
86 static LogicalResult verifyWeights(Operation *op,
88  std::size_t expectedWeightsNum,
89  llvm::StringRef weightAnchorName,
90  llvm::StringRef weightRefName) {
91  if (weights.empty())
92  return success();
93 
94  if (weights.size() != expectedWeightsNum)
95  return op->emitError() << "expects number of " << weightAnchorName
96  << " weights to match number of " << weightRefName
97  << ": " << weights.size() << " vs "
98  << expectedWeightsNum;
99 
// NOTE(review): this message always says "branch" even when called for region
// weights ("region"); consider building it from `weightAnchorName`.
100  if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
101  return op->emitError() << "branch weights cannot all be zero";
102 
103  return success();
104 }
105 
// Body of verifyBranchWeights(Operation *op). When the
// WeightedBranchOpInterface weights are non-empty, they must provide exactly
// one weight per successor of `op` (checked by verifyWeights).
// NOTE(review): the signature (listing line 106) is missing from this copy.
107  llvm::ArrayRef<int32_t> weights =
108  cast<WeightedBranchOpInterface>(op).getWeights();
109  return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
110  "successors");
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // WeightedRegionBranchOpInterface
115 //===----------------------------------------------------------------------===//
116 
// Body of verifyRegionBranchWeights(Operation *op). When the
// WeightedRegionBranchOpInterface weights are non-empty, they must provide
// exactly one weight per region of `op` (checked by verifyWeights).
// NOTE(review): the signature (listing line 117) is missing from this copy.
118  llvm::ArrayRef<int32_t> weights =
119  cast<WeightedRegionBranchOpInterface>(op).getWeights();
120  return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // RegionBranchOpInterface
125 //===----------------------------------------------------------------------===//
126 
// printRegionEdgeName appends "from <source> to <successor>" to `diag` and
// returns `diag` so callers can keep streaming into it.
// A region point is printed as "Region #<n>". A parent point is printed as
// "parent operands" on the source side and "parent results" on the
// successor side.
// NOTE(review): the first signature line (listing line 127) is missing from
// this copy.
128  RegionBranchPoint sourceNo,
129  RegionBranchPoint succRegionNo) {
130  diag << "from ";
131  if (Region *region = sourceNo.getRegionOrNull())
132  diag << "Region #" << region->getRegionNumber();
133  else
134  diag << "parent operands";
135 
136  diag << " to ";
137  if (Region *region = succRegionNo.getRegionOrNull())
138  diag << "Region #" << region->getRegionNumber();
139  else
140  diag << "parent results";
141  return diag;
142 }
143 
144 /// Verify that types match along all region control flow edges originating from
145 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
146 /// types of the inputs that flow to a successor region.
/// For every successor of `sourcePoint`, the source types must have the same
/// count as the successor's inputs. Each pair must also be accepted by
/// RegionBranchOpInterface::areTypesCompatible.
/// If `getInputsTypesForRegion` fails, this returns failure without emitting
/// a diagnostic of its own. Presumably the callback has already reported the
/// problem; confirm for new callers.
// NOTE(review): listing lines 148 (start of the signature) and 153 (the
// declaration of `successors`) are missing from this copy.
147 static LogicalResult
149  function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
150  getInputsTypesForRegion) {
151  auto regionInterface = cast<RegionBranchOpInterface>(op);
152 
154  regionInterface.getSuccessorRegions(sourcePoint, successors);
155 
156  for (RegionSuccessor &succ : successors) {
157  FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
158  if (failed(sourceTypes))
159  return failure();
160 
161  TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
162  if (sourceTypes->size() != succInputsTypes.size()) {
163  InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
164  return printRegionEdgeName(diag, sourcePoint, succ)
165  << ": source has " << sourceTypes->size()
166  << " operands, but target successor needs "
167  << succInputsTypes.size();
168  }
169 
// Sizes match, so compare types pairwise; the index is reported on mismatch.
170  for (const auto &typesIdx :
171  llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
172  Type sourceType = std::get<0>(typesIdx.value());
173  Type inputType = std::get<1>(typesIdx.value());
174  if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
175  InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
176  return printRegionEdgeName(diag, sourcePoint, succ)
177  << ": source type #" << typesIdx.index() << " " << sourceType
178  << " should match input type #" << typesIdx.index() << " "
179  << inputType;
180  }
181  }
182  }
183  return success();
184 }
185 
186 /// Verify that types match along control flow edges described by the given op.
/// The edges checked are:
///  - edges out of the parent op, using the entry successor operands;
///  - edges out of each region that contains at least one
///    RegionBranchTerminatorOpInterface terminator.
/// For each region, all such terminators must forward mutually compatible
/// operand types to a given successor.
// NOTE(review): listing lines 187 (the signature), 195 (the start of the
// parent-edge verifyTypesAlongAllEdges call) and 218 (the declaration of
// `regionReturnOps`) are missing from this copy.
188  auto regionInterface = cast<RegionBranchOpInterface>(op);
189 
190  auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
191  return regionInterface.getEntrySuccessorOperands(point).getTypes();
192  };
193 
194  // Verify types along control flow edges originating from the parent.
196  inputTypesFromParent)))
197  return failure();
198 
// Element-wise compatibility of two type lists of equal length.
199  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
200  if (lhs.size() != rhs.size())
201  return false;
202  for (auto types : llvm::zip(lhs, rhs)) {
203  if (!regionInterface.areTypesCompatible(std::get<0>(types),
204  std::get<1>(types))) {
205  return false;
206  }
207  }
208  return true;
209  };
210 
211  // Verify types along control flow edges originating from each region.
212  for (Region &region : op->getRegions()) {
213 
214  // Since there can be multiple terminators implementing the
215  // `RegionBranchTerminatorOpInterface`, all should have the same operand
216  // types when passing them to the same region.
217 
219  for (Block &block : region)
220  if (!block.empty())
221  if (auto terminator =
222  dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
223  regionReturnOps.push_back(terminator);
224 
225  // If there is no return-like terminator, the op itself should verify
226  // type consistency.
227  if (regionReturnOps.empty())
228  continue;
229 
// Returns the operand types forwarded to `point`, taken from the first
// terminator. Fails if a later terminator forwards incompatible types.
230  auto inputTypesForRegion =
231  [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
232  std::optional<OperandRange> regionReturnOperands;
233  for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
234  auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
235 
236  if (!regionReturnOperands) {
237  regionReturnOperands = terminatorOperands;
238  continue;
239  }
240 
241  // Found more than one ReturnLike terminator. Make sure the operand
242  // types match with the first one.
243  if (!areTypesCompatible(regionReturnOperands->getTypes(),
244  terminatorOperands.getTypes())) {
// NOTE(review): the literal below has no trailing space. printRegionEdgeName
// starts with "from ", so the message reads "along control flow edgefrom ...".
// The sibling message in verifyTypesAlongAllEdges is written as
// "along control flow edge ". Consider adding the space.
245  InFlightDiagnostic diag = op->emitOpError("along control flow edge");
246  return printRegionEdgeName(diag, region, point)
247  << " operands mismatch between return-like terminators";
248  }
249  }
250 
251  // All successors get the same set of operand types.
252  return TypeRange(regionReturnOperands->getTypes());
253  };
254 
255  if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
256  return failure();
257  }
258 
259  return success();
260 }
261 
262 /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
263 /// this function returns "true" for a successor region. The first parameter is
264 /// the successor region. The second parameter indicates all already visited
265 /// regions.
267 
268 /// Traverse the region graph starting at `begin`. The traversal is interrupted
269 /// if `stopCondition` evaluates to "true" for a successor region. In that case,
270 /// this function returns "true". Otherwise, if the traversal was not
271 /// interrupted, this function returns "false".
272 static bool traverseRegionGraph(Region *begin,
273  StopConditionFn stopConditionFn) {
274  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
275  SmallVector<bool> visited(op->getNumRegions(), false);
276  visited[begin->getRegionNumber()] = true;
277 
278  // Retrieve all successors of the region and enqueue them in the worklist.
279  SmallVector<Region *> worklist;
280  auto enqueueAllSuccessors = [&](Region *region) {
281  SmallVector<RegionSuccessor> successors;
282  op.getSuccessorRegions(region, successors);
283  for (RegionSuccessor successor : successors)
284  if (!successor.isParent())
285  worklist.push_back(successor.getSuccessor());
286  };
287  enqueueAllSuccessors(begin);
288 
289  // Process all regions in the worklist via DFS.
290  while (!worklist.empty()) {
291  Region *nextRegion = worklist.pop_back_val();
292  if (stopConditionFn(nextRegion, visited))
293  return true;
294  if (visited[nextRegion->getRegionNumber()])
295  continue;
296  visited[nextRegion->getRegionNumber()] = true;
297  enqueueAllSuccessors(nextRegion);
298  }
299 
300  return false;
301 }
302 
303 /// Return `true` if region `r` is reachable from region `begin` according to
304 /// the RegionBranchOpInterface (by taking a branch).
305 static bool isRegionReachable(Region *begin, Region *r) {
306  assert(begin->getParentOp() == r->getParentOp() &&
307  "expected that both regions belong to the same op");
308  return traverseRegionGraph(begin,
309  [&](Region *nextRegion, ArrayRef<bool> visited) {
310  // Interrupt traversal if `r` was reached.
311  return nextRegion == r;
312  });
313 }
314 
315 /// Return `true` if `a` and `b` are in mutually exclusive regions.
316 ///
317 /// 1. Find the closest common ancestor of `a` and `b` that implements
318 /// RegionBranchOpInterface.
319 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
320 /// contained.
321 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
322 /// mutually exclusive if they are not reachable from each other as per
323 /// RegionBranchOpInterface::getSuccessorRegions.
/// Only that closest common ancestor is examined; outer ancestors are never
/// consulted. Two operations in the same region are never reported as
/// exclusive. If no common RegionBranchOpInterface ancestor exists, the
/// result is `false`.
// NOTE(review): the signature (listing line 324) is missing from this copy.
325  assert(a && "expected non-empty operation");
326  assert(b && "expected non-empty operation");
327 
328  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
329  while (branchOp) {
330  // Check if b is inside branchOp. (We already know that a is.)
331  if (!branchOp->isProperAncestor(b)) {
332  // Check next enclosing RegionBranchOpInterface.
333  branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
334  continue;
335  }
336 
337  // b is contained in branchOp. Retrieve the regions in which `a` and `b`
338  // are contained.
339  Region *regionA = nullptr, *regionB = nullptr;
340  for (Region &r : branchOp->getRegions()) {
341  if (r.findAncestorOpInRegion(*a)) {
342  assert(!regionA && "already found a region for a");
343  regionA = &r;
344  }
345  if (r.findAncestorOpInRegion(*b)) {
346  assert(!regionB && "already found a region for b");
347  regionB = &r;
348  }
349  }
350  assert(regionA && regionB && "could not find region of op");
351 
352  // `a` and `b` are in mutually exclusive regions if both regions are
353  // distinct and neither region is reachable from the other region.
354  return regionA != regionB && !isRegionReachable(regionA, regionB) &&
355  !isRegionReachable(regionB, regionA);
356  }
357 
358  // Could not find a common RegionBranchOpInterface among a's and b's
359  // ancestors.
360  return false;
361 }
362 
363 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
364  Region *region = &getOperation()->getRegion(index);
365  return isRegionReachable(region, region);
366 }
367 
368 bool RegionBranchOpInterface::hasLoop() {
369  SmallVector<RegionSuccessor> entryRegions;
370  getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
371  for (RegionSuccessor successor : entryRegions)
372  if (!successor.isParent() &&
373  traverseRegionGraph(successor.getSuccessor(),
374  [](Region *nextRegion, ArrayRef<bool> visited) {
375  // Interrupt traversal if the region was already
376  // visited.
377  return visited[nextRegion->getRegionNumber()];
378  }))
379  return true;
380  return false;
381 }
382 
// Body of getEnclosingRepetitiveRegion(Operation *op). It walks outward
// through the regions enclosing `op`, one nesting level at a time (note that
// `op` itself is reassigned). It returns the first region whose parent op
// implements RegionBranchOpInterface and reports that region as repetitive,
// or nullptr if no enclosing region qualifies.
// NOTE(review): the signature (listing line 383) is missing from this copy.
384  while (Region *region = op->getParentRegion()) {
385  op = region->getParentOp();
386  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
387  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
388  return region;
389  }
390  return nullptr;
391 }
392 
// Body of getEnclosingRepetitiveRegion(Value value). It behaves like the
// Operation overload but starts from the region in which `value` is defined.
// It returns the first enclosing region whose parent RegionBranchOpInterface
// op reports it as repetitive, or nullptr if there is none.
// NOTE(review): the signature (listing line 393) is missing from this copy.
394  Region *region = value.getParentRegion();
395  while (region) {
396  Operation *op = region->getParentOp();
397  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
398  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
399  return region;
400  region = op->getParentRegion();
401  }
402  return nullptr;
403 }
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionBranchPoint succRegionNo)
static LogicalResult verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionBranchPoint)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourcePoint.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getSuccessor(unsigned index)
Definition: Operation.h:708
unsigned getNumSuccessors()
Definition: Operation.h:706
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described by the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...