MLIR  16.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
22 
24  : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25 }
26 
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28  MutableOperandRange forwardedOperands)
29  : producedOperandCount(producedOperandCount),
30  forwardedOperands(std::move(forwardedOperands)) {}
31 
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or None if
38 /// `operandIndex` isn't a successor operand index.
41  unsigned operandIndex, Block *successor) {
42  OperandRange forwardedOperands = operands.getForwardedOperands();
43  // Check that the operands are valid.
44  if (forwardedOperands.empty())
45  return llvm::None;
46 
47  // Check to ensure that this operand is within the range.
48  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49  if (operandIndex < operandsStart ||
50  operandIndex >= (operandsStart + forwardedOperands.size()))
51  return llvm::None;
52 
53  // Index the successor.
54  unsigned argIndex =
55  operands.getProducedOperandCount() + operandIndex - operandsStart;
56  return successor->getArgument(argIndex);
57 }
58 
59 /// Verify that the given operands match those of the given successor block.
62  const SuccessorOperands &operands) {
63  // Check the count.
64  unsigned operandCount = operands.size();
65  Block *destBB = op->getSuccessor(succNo);
66  if (operandCount != destBB->getNumArguments())
67  return op->emitError() << "branch has " << operandCount
68  << " operands for successor #" << succNo
69  << ", but target block has "
70  << destBB->getNumArguments();
71 
72  // Check the types.
73  for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74  ++i) {
75  if (!cast<BranchOpInterface>(op).areTypesCompatible(
76  operands[i].getType(), destBB->getArgument(i).getType()))
77  return op->emitError() << "type mismatch for bb argument #" << i
78  << " of successor #" << succNo;
79  }
80  return success();
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
86 
87 /// Verify that types match along all region control flow edges originating from
88 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
89 /// op). `getInputsTypesForRegion` is a function that returns the types of the
90 /// inputs that flow from `sourceIndex' to the given region, or llvm::None if
91 /// the exact type match verification is not necessary (e.g., if the Op verifies
92 /// the match itself).
93 static LogicalResult
96  getInputsTypesForRegion) {
97  auto regionInterface = cast<RegionBranchOpInterface>(op);
98 
100  regionInterface.getSuccessorRegions(sourceNo, successors);
101 
102  for (RegionSuccessor &succ : successors) {
103  Optional<unsigned> succRegionNo;
104  if (!succ.isParent())
105  succRegionNo = succ.getSuccessor()->getRegionNumber();
106 
107  auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108  diag << "from ";
109  if (sourceNo)
110  diag << "Region #" << sourceNo.value();
111  else
112  diag << "parent operands";
113 
114  diag << " to ";
115  if (succRegionNo)
116  diag << "Region #" << succRegionNo.value();
117  else
118  diag << "parent results";
119  return diag;
120  };
121 
122  Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
123  if (!sourceTypes.has_value())
124  continue;
125 
126  TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
127  if (sourceTypes->size() != succInputsTypes.size()) {
128  InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
129  return printEdgeName(diag) << ": source has " << sourceTypes->size()
130  << " operands, but target successor needs "
131  << succInputsTypes.size();
132  }
133 
134  for (const auto &typesIdx :
135  llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
136  Type sourceType = std::get<0>(typesIdx.value());
137  Type inputType = std::get<1>(typesIdx.value());
138  if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
139  InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
140  return printEdgeName(diag)
141  << ": source type #" << typesIdx.index() << " " << sourceType
142  << " should match input type #" << typesIdx.index() << " "
143  << inputType;
144  }
145  }
146  }
147  return success();
148 }
149 
150 /// Verify that types match along control flow edges described the given op.
152  auto regionInterface = cast<RegionBranchOpInterface>(op);
153 
154  auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
155  return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
156  };
157 
158  // Verify types along control flow edges originating from the parent.
159  if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
160  return failure();
161 
162  // RegionBranchOpInterface should not be implemented by Ops that do not have
163  // attached regions.
164  assert(op->getNumRegions() != 0);
165 
166  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
167  if (lhs.size() != rhs.size())
168  return false;
169  for (auto types : llvm::zip(lhs, rhs)) {
170  if (!regionInterface.areTypesCompatible(std::get<0>(types),
171  std::get<1>(types))) {
172  return false;
173  }
174  }
175  return true;
176  };
177 
178  // Verify types along control flow edges originating from each region.
179  for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
180  Region &region = op->getRegion(regionNo);
181 
182  // Since there can be multiple `ReturnLike` terminators or others
183  // implementing the `RegionBranchTerminatorOpInterface`, all should have the
184  // same operand types when passing them to the same region.
185 
186  Optional<OperandRange> regionReturnOperands;
187  for (Block &block : region) {
188  Operation *terminator = block.getTerminator();
189  auto terminatorOperands =
190  getRegionBranchSuccessorOperands(terminator, regionNo);
191  if (!terminatorOperands)
192  continue;
193 
194  if (!regionReturnOperands) {
195  regionReturnOperands = terminatorOperands;
196  continue;
197  }
198 
199  // Found more than one ReturnLike terminator. Make sure the operand types
200  // match with the first one.
201  if (!areTypesCompatible(regionReturnOperands->getTypes(),
202  terminatorOperands->getTypes()))
203  return op->emitOpError("Region #")
204  << regionNo
205  << " operands mismatch between return-like terminators";
206  }
207 
208  auto inputTypesFromRegion =
209  [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
210  // If there is no return-like terminator, the op itself should verify
211  // type consistency.
212  if (!regionReturnOperands)
213  return llvm::None;
214 
215  // All successors get the same set of operand types.
216  return TypeRange(regionReturnOperands->getTypes());
217  };
218 
219  if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
220  return failure();
221  }
222 
223  return success();
224 }
225 
226 /// Return `true` if region `r` is reachable from region `begin` according to
227 /// the RegionBranchOpInterface (by taking a branch).
228 static bool isRegionReachable(Region *begin, Region *r) {
229  assert(begin->getParentOp() == r->getParentOp() &&
230  "expected that both regions belong to the same op");
231  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
232  SmallVector<bool> visited(op->getNumRegions(), false);
233  visited[begin->getRegionNumber()] = true;
234 
235  // Retrieve all successors of the region and enqueue them in the worklist.
236  SmallVector<unsigned> worklist;
237  auto enqueueAllSuccessors = [&](unsigned index) {
238  SmallVector<RegionSuccessor> successors;
239  op.getSuccessorRegions(index, successors);
240  for (RegionSuccessor successor : successors)
241  if (!successor.isParent())
242  worklist.push_back(successor.getSuccessor()->getRegionNumber());
243  };
244  enqueueAllSuccessors(begin->getRegionNumber());
245 
246  // Process all regions in the worklist via DFS.
247  while (!worklist.empty()) {
248  unsigned nextRegion = worklist.pop_back_val();
249  if (nextRegion == r->getRegionNumber())
250  return true;
251  if (visited[nextRegion])
252  continue;
253  visited[nextRegion] = true;
254  enqueueAllSuccessors(nextRegion);
255  }
256 
257  return false;
258 }
259 
260 /// Return `true` if `a` and `b` are in mutually exclusive regions.
261 ///
262 /// 1. Find the first common of `a` and `b` (ancestor) that implements
263 /// RegionBranchOpInterface.
264 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
265 /// contained.
266 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
267 /// mutually exclusive if they are not reachable from each other as per
268 /// RegionBranchOpInterface::getSuccessorRegions.
270  assert(a && "expected non-empty operation");
271  assert(b && "expected non-empty operation");
272 
273  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
274  while (branchOp) {
275  // Check if b is inside branchOp. (We already know that a is.)
276  if (!branchOp->isProperAncestor(b)) {
277  // Check next enclosing RegionBranchOpInterface.
278  branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
279  continue;
280  }
281 
282  // b is contained in branchOp. Retrieve the regions in which `a` and `b`
283  // are contained.
284  Region *regionA = nullptr, *regionB = nullptr;
285  for (Region &r : branchOp->getRegions()) {
286  if (r.findAncestorOpInRegion(*a)) {
287  assert(!regionA && "already found a region for a");
288  regionA = &r;
289  }
290  if (r.findAncestorOpInRegion(*b)) {
291  assert(!regionB && "already found a region for b");
292  regionB = &r;
293  }
294  }
295  assert(regionA && regionB && "could not find region of op");
296 
297  // `a` and `b` are in mutually exclusive regions if both regions are
298  // distinct and neither region is reachable from the other region.
299  return regionA != regionB && !isRegionReachable(regionA, regionB) &&
300  !isRegionReachable(regionB, regionA);
301  }
302 
303  // Could not find a common RegionBranchOpInterface among a's and b's
304  // ancestors.
305  return false;
306 }
307 
308 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
309  Region *region = &getOperation()->getRegion(index);
310  return isRegionReachable(region, region);
311 }
312 
313 void RegionBranchOpInterface::getSuccessorRegions(
315  unsigned numInputs = 0;
316  if (index) {
317  // If the predecessor is a region, get the number of operands from an
318  // exiting terminator in the region.
319  for (Block &block : getOperation()->getRegion(*index)) {
320  Operation *terminator = block.getTerminator();
321  if (getRegionBranchSuccessorOperands(terminator, *index)) {
322  numInputs = terminator->getNumOperands();
323  break;
324  }
325  }
326  } else {
327  // Otherwise, use the number of parent operation operands.
328  numInputs = getOperation()->getNumOperands();
329  }
330  SmallVector<Attribute, 2> operands(numInputs, nullptr);
331  getSuccessorRegions(index, operands, regions);
332 }
333 
335  while (Region *region = op->getParentRegion()) {
336  op = region->getParentOp();
337  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
338  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
339  return region;
340  }
341  return nullptr;
342 }
343 
345  Region *region = value.getParentRegion();
346  while (region) {
347  Operation *op = region->getParentOp();
348  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
349  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
350  return region;
351  region = op->getParentRegion();
352  }
353  return nullptr;
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // RegionBranchTerminatorOpInterface
358 //===----------------------------------------------------------------------===//
359 
360 /// Returns true if the given operation is either annotated with the
361 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
363  return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
364  operation->hasTrait<OpTrait::ReturnLike>();
365 }
366 
367 /// Returns the mutable operands that are passed to the region with the given
368 /// `regionIndex`. If the operation does not implement the
369 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
370 /// result will be `llvm::None`. In all other cases, the resulting
371 /// `OperandRange` represents all operands that are passed to the specified
372 /// successor region. If `regionIndex` is `llvm::None`, all operands that are
373 /// passed to the parent operation will be returned.
376  Optional<unsigned> regionIndex) {
377  // Try to query a RegionBranchTerminatorOpInterface to determine
378  // all successor operands that will be passed to the successor
379  // input arguments.
380  if (auto regionTerminatorInterface =
381  dyn_cast<RegionBranchTerminatorOpInterface>(operation))
382  return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
383 
384  // TODO: The ReturnLike trait should imply a default implementation of the
385  // RegionBranchTerminatorOpInterface. This would make this code significantly
386  // easier. Furthermore, this may even make this function obsolete.
387  if (operation->hasTrait<OpTrait::ReturnLike>())
388  return MutableOperandRange(operation);
389  return llvm::None;
390 }
391 
392 /// Returns the read only operands that are passed to the region with the given
393 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
394 /// information.
397  Optional<unsigned> regionIndex) {
398  auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
399  return range ? Optional<OperandRange>(*range) : llvm::None;
400 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
static std::string diag(llvm::Value &v)
Optional< OperandRange > getRegionBranchSuccessorOperands(Operation *operation, Optional< unsigned > regionIndex)
Returns the read only operands that are passed to the region with the given regionIndex.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:310
Block represents an ordered list of Operations.
Definition: Block.h:29
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:263
bool isRegionReturnLike(Operation *operation)
Returns true if the given operation is either annotated with the ReturnLike trait or implements the R...
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:169
static constexpr const bool value
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
unsigned size() const
Returns the amount of operands passed to the successor.
Block * getSuccessor(unsigned index)
Definition: Operation.h:508
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described the given op.
unsigned getNumArguments()
Definition: Block.h:119
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
static LogicalResult verifyTypesAlongAllEdges(Operation *op, Optional< unsigned > sourceNo, function_ref< Optional< TypeRange >(Optional< unsigned >)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourceNo (region # if so...
Optional< MutableOperandRange > getMutableRegionBranchSuccessorOperands(Operation *operation, Optional< unsigned > regionIndex)
Returns the mutable operands that are passed to the region with the given regionIndex.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:114
This class models how operands are forwarded to block arguments in control flow.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
Optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a successor of a region.
Type getType() const
Return the type of this value.
Definition: Value.h:118
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:161
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:508
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
This trait indicates that a terminator operation is "return-like".