MLIR  16.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include <utility>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
22 
24  : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25 }
26 
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28  MutableOperandRange forwardedOperands)
29  : producedOperandCount(producedOperandCount),
30  forwardedOperands(std::move(forwardedOperands)) {}
31 
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
41  unsigned operandIndex, Block *successor) {
42  OperandRange forwardedOperands = operands.getForwardedOperands();
43  // Check that the operands are valid.
44  if (forwardedOperands.empty())
45  return std::nullopt;
46 
47  // Check to ensure that this operand is within the range.
48  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49  if (operandIndex < operandsStart ||
50  operandIndex >= (operandsStart + forwardedOperands.size()))
51  return std::nullopt;
52 
53  // Index the successor.
54  unsigned argIndex =
55  operands.getProducedOperandCount() + operandIndex - operandsStart;
56  return successor->getArgument(argIndex);
57 }
58 
59 /// Verify that the given operands match those of the given successor block.
62  const SuccessorOperands &operands) {
63  // Check the count.
64  unsigned operandCount = operands.size();
65  Block *destBB = op->getSuccessor(succNo);
66  if (operandCount != destBB->getNumArguments())
67  return op->emitError() << "branch has " << operandCount
68  << " operands for successor #" << succNo
69  << ", but target block has "
70  << destBB->getNumArguments();
71 
72  // Check the types.
73  for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74  ++i) {
75  if (!cast<BranchOpInterface>(op).areTypesCompatible(
76  operands[i].getType(), destBB->getArgument(i).getType()))
77  return op->emitError() << "type mismatch for bb argument #" << i
78  << " of successor #" << succNo;
79  }
80  return success();
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
86 
87 /// Verify that types match along all region control flow edges originating from
88 /// `sourceNo` (region # if source is a region, std::nullopt if source is parent
89 /// op). `getInputsTypesForRegion` is a function that returns the types of the
90 /// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
91 /// the exact type match verification is not necessary (e.g., if the Op verifies
92 /// the match itself).
93 static LogicalResult
96  getInputsTypesForRegion) {
97  auto regionInterface = cast<RegionBranchOpInterface>(op);
98 
100  regionInterface.getSuccessorRegions(sourceNo, successors);
101 
102  for (RegionSuccessor &succ : successors) {
103  Optional<unsigned> succRegionNo;
104  if (!succ.isParent())
105  succRegionNo = succ.getSuccessor()->getRegionNumber();
106 
107  auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108  diag << "from ";
109  if (sourceNo)
110  diag << "Region #" << sourceNo.value();
111  else
112  diag << "parent operands";
113 
114  diag << " to ";
115  if (succRegionNo)
116  diag << "Region #" << succRegionNo.value();
117  else
118  diag << "parent results";
119  return diag;
120  };
121 
122  Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
123  if (!sourceTypes.has_value())
124  continue;
125 
126  TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
127  if (sourceTypes->size() != succInputsTypes.size()) {
128  InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
129  return printEdgeName(diag) << ": source has " << sourceTypes->size()
130  << " operands, but target successor needs "
131  << succInputsTypes.size();
132  }
133 
134  for (const auto &typesIdx :
135  llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
136  Type sourceType = std::get<0>(typesIdx.value());
137  Type inputType = std::get<1>(typesIdx.value());
138  if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
139  InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
140  return printEdgeName(diag)
141  << ": source type #" << typesIdx.index() << " " << sourceType
142  << " should match input type #" << typesIdx.index() << " "
143  << inputType;
144  }
145  }
146  }
147  return success();
148 }
149 
150 /// Verify that types match along control flow edges described the given op.
152  auto regionInterface = cast<RegionBranchOpInterface>(op);
153 
154  auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
155  return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
156  };
157 
158  // Verify types along control flow edges originating from the parent.
159  if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
160  return failure();
161 
162  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
163  if (lhs.size() != rhs.size())
164  return false;
165  for (auto types : llvm::zip(lhs, rhs)) {
166  if (!regionInterface.areTypesCompatible(std::get<0>(types),
167  std::get<1>(types))) {
168  return false;
169  }
170  }
171  return true;
172  };
173 
174  // Verify types along control flow edges originating from each region.
175  for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
176  Region &region = op->getRegion(regionNo);
177 
178  // Since there can be multiple `ReturnLike` terminators or others
179  // implementing the `RegionBranchTerminatorOpInterface`, all should have the
180  // same operand types when passing them to the same region.
181 
182  Optional<OperandRange> regionReturnOperands;
183  for (Block &block : region) {
184  Operation *terminator = block.getTerminator();
185  auto terminatorOperands =
186  getRegionBranchSuccessorOperands(terminator, regionNo);
187  if (!terminatorOperands)
188  continue;
189 
190  if (!regionReturnOperands) {
191  regionReturnOperands = terminatorOperands;
192  continue;
193  }
194 
195  // Found more than one ReturnLike terminator. Make sure the operand types
196  // match with the first one.
197  if (!areTypesCompatible(regionReturnOperands->getTypes(),
198  terminatorOperands->getTypes()))
199  return op->emitOpError("Region #")
200  << regionNo
201  << " operands mismatch between return-like terminators";
202  }
203 
204  auto inputTypesFromRegion =
205  [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
206  // If there is no return-like terminator, the op itself should verify
207  // type consistency.
208  if (!regionReturnOperands)
209  return std::nullopt;
210 
211  // All successors get the same set of operand types.
212  return TypeRange(regionReturnOperands->getTypes());
213  };
214 
215  if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
216  return failure();
217  }
218 
219  return success();
220 }
221 
222 /// Return `true` if region `r` is reachable from region `begin` according to
223 /// the RegionBranchOpInterface (by taking a branch).
224 static bool isRegionReachable(Region *begin, Region *r) {
225  assert(begin->getParentOp() == r->getParentOp() &&
226  "expected that both regions belong to the same op");
227  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
228  SmallVector<bool> visited(op->getNumRegions(), false);
229  visited[begin->getRegionNumber()] = true;
230 
231  // Retrieve all successors of the region and enqueue them in the worklist.
232  SmallVector<unsigned> worklist;
233  auto enqueueAllSuccessors = [&](unsigned index) {
234  SmallVector<RegionSuccessor> successors;
235  op.getSuccessorRegions(index, successors);
236  for (RegionSuccessor successor : successors)
237  if (!successor.isParent())
238  worklist.push_back(successor.getSuccessor()->getRegionNumber());
239  };
240  enqueueAllSuccessors(begin->getRegionNumber());
241 
242  // Process all regions in the worklist via DFS.
243  while (!worklist.empty()) {
244  unsigned nextRegion = worklist.pop_back_val();
245  if (nextRegion == r->getRegionNumber())
246  return true;
247  if (visited[nextRegion])
248  continue;
249  visited[nextRegion] = true;
250  enqueueAllSuccessors(nextRegion);
251  }
252 
253  return false;
254 }
255 
256 /// Return `true` if `a` and `b` are in mutually exclusive regions.
257 ///
258 /// 1. Find the first common of `a` and `b` (ancestor) that implements
259 /// RegionBranchOpInterface.
260 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
261 /// contained.
262 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
263 /// mutually exclusive if they are not reachable from each other as per
264 /// RegionBranchOpInterface::getSuccessorRegions.
266  assert(a && "expected non-empty operation");
267  assert(b && "expected non-empty operation");
268 
269  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
270  while (branchOp) {
271  // Check if b is inside branchOp. (We already know that a is.)
272  if (!branchOp->isProperAncestor(b)) {
273  // Check next enclosing RegionBranchOpInterface.
274  branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
275  continue;
276  }
277 
278  // b is contained in branchOp. Retrieve the regions in which `a` and `b`
279  // are contained.
280  Region *regionA = nullptr, *regionB = nullptr;
281  for (Region &r : branchOp->getRegions()) {
282  if (r.findAncestorOpInRegion(*a)) {
283  assert(!regionA && "already found a region for a");
284  regionA = &r;
285  }
286  if (r.findAncestorOpInRegion(*b)) {
287  assert(!regionB && "already found a region for b");
288  regionB = &r;
289  }
290  }
291  assert(regionA && regionB && "could not find region of op");
292 
293  // `a` and `b` are in mutually exclusive regions if both regions are
294  // distinct and neither region is reachable from the other region.
295  return regionA != regionB && !isRegionReachable(regionA, regionB) &&
296  !isRegionReachable(regionB, regionA);
297  }
298 
299  // Could not find a common RegionBranchOpInterface among a's and b's
300  // ancestors.
301  return false;
302 }
303 
304 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
305  Region *region = &getOperation()->getRegion(index);
306  return isRegionReachable(region, region);
307 }
308 
309 void RegionBranchOpInterface::getSuccessorRegions(
311  unsigned numInputs = 0;
312  if (index) {
313  // If the predecessor is a region, get the number of operands from an
314  // exiting terminator in the region.
315  for (Block &block : getOperation()->getRegion(*index)) {
316  Operation *terminator = block.getTerminator();
317  if (getRegionBranchSuccessorOperands(terminator, *index)) {
318  numInputs = terminator->getNumOperands();
319  break;
320  }
321  }
322  } else {
323  // Otherwise, use the number of parent operation operands.
324  numInputs = getOperation()->getNumOperands();
325  }
326  SmallVector<Attribute, 2> operands(numInputs, nullptr);
327  getSuccessorRegions(index, operands, regions);
328 }
329 
331  while (Region *region = op->getParentRegion()) {
332  op = region->getParentOp();
333  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
334  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
335  return region;
336  }
337  return nullptr;
338 }
339 
341  Region *region = value.getParentRegion();
342  while (region) {
343  Operation *op = region->getParentOp();
344  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
345  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
346  return region;
347  region = op->getParentRegion();
348  }
349  return nullptr;
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // RegionBranchTerminatorOpInterface
354 //===----------------------------------------------------------------------===//
355 
356 /// Returns true if the given operation is either annotated with the
357 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
359  return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
360  operation->hasTrait<OpTrait::ReturnLike>();
361 }
362 
363 /// Returns the mutable operands that are passed to the region with the given
364 /// `regionIndex`. If the operation does not implement the
365 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
366 /// result will be `std::nullopt`. In all other cases, the resulting
367 /// `OperandRange` represents all operands that are passed to the specified
368 /// successor region. If `regionIndex` is `std::nullopt`, all operands that are
369 /// passed to the parent operation will be returned.
372  Optional<unsigned> regionIndex) {
373  // Try to query a RegionBranchTerminatorOpInterface to determine
374  // all successor operands that will be passed to the successor
375  // input arguments.
376  if (auto regionTerminatorInterface =
377  dyn_cast<RegionBranchTerminatorOpInterface>(operation))
378  return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
379 
380  // TODO: The ReturnLike trait should imply a default implementation of the
381  // RegionBranchTerminatorOpInterface. This would make this code significantly
382  // easier. Furthermore, this may even make this function obsolete.
383  if (operation->hasTrait<OpTrait::ReturnLike>())
384  return MutableOperandRange(operation);
385  return std::nullopt;
386 }
387 
388 /// Returns the read only operands that are passed to the region with the given
389 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
390 /// information.
393  Optional<unsigned> regionIndex) {
394  auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
395  return range ? Optional<OperandRange>(*range) : std::nullopt;
396 }
static LogicalResult verifyTypesAlongAllEdges(Operation *op, Optional< unsigned > sourceNo, function_ref< Optional< TypeRange >(Optional< unsigned >)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourceNo (region # if source is a region, std::nullopt if source is parent op).
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by taking a branch).
static std::string diag(llvm::Value &value)
static constexpr const bool value
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
unsigned getNumArguments()
Definition: Block.h:117
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:114
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:41
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:532
Block * getSuccessor(unsigned index)
Definition: Operation.h:512
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
unsigned getNumOperands()
Definition: Operation.h:263
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:169
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:490
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:161
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:512
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the successor.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is within the range of operands, or std::nullopt if operandIndex isn't a successor operand index.
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described by the given op.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBranchOpInterface.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Optional< OperandRange > getRegionBranchSuccessorOperands(Operation *operation, Optional< unsigned > regionIndex)
Returns the read only operands that are passed to the region with the given regionIndex.
bool isRegionReturnLike(Operation *operation)
Returns true if the given operation is either annotated with the ReturnLike trait or implements the RegionBranchTerminatorOpInterface.
Optional< MutableOperandRange > getMutableRegionBranchSuccessorOperands(Operation *operation, Optional< unsigned > regionIndex)
Returns the mutable operands that are passed to the region with the given regionIndex.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This trait indicates that a terminator operation is "return-like".