MLIR  18.0.0git
ControlFlowInterfaces.cpp
Go to the documentation of this file.
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/BuiltinTypes.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
22 
24  : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25 }
26 
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28  MutableOperandRange forwardedOperands)
29  : producedOperandCount(producedOperandCount),
30  forwardedOperands(std::move(forwardedOperands)) {}
31 
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if `operandIndex` is within the range of `operands`, or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional<BlockArgument>
41  unsigned operandIndex, Block *successor) {
42  OperandRange forwardedOperands = operands.getForwardedOperands();
43  // Check that the operands are valid.
44  if (forwardedOperands.empty())
45  return std::nullopt;
46 
47  // Check to ensure that this operand is within the range.
48  unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49  if (operandIndex < operandsStart ||
50  operandIndex >= (operandsStart + forwardedOperands.size()))
51  return std::nullopt;
52 
53  // Index the successor.
54  unsigned argIndex =
55  operands.getProducedOperandCount() + operandIndex - operandsStart;
56  return successor->getArgument(argIndex);
57 }
58 
59 /// Verify that the given operands match those of the given successor block.
62  const SuccessorOperands &operands) {
63  // Check the count.
64  unsigned operandCount = operands.size();
65  Block *destBB = op->getSuccessor(succNo);
66  if (operandCount != destBB->getNumArguments())
67  return op->emitError() << "branch has " << operandCount
68  << " operands for successor #" << succNo
69  << ", but target block has "
70  << destBB->getNumArguments();
71 
72  // Check the types.
73  for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74  ++i) {
75  if (!cast<BranchOpInterface>(op).areTypesCompatible(
76  operands[i].getType(), destBB->getArgument(i).getType()))
77  return op->emitError() << "type mismatch for bb argument #" << i
78  << " of successor #" << succNo;
79  }
80  return success();
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
86 
88  RegionBranchPoint sourceNo,
89  RegionBranchPoint succRegionNo) {
90  diag << "from ";
91  if (Region *region = sourceNo.getRegionOrNull())
92  diag << "Region #" << region->getRegionNumber();
93  else
94  diag << "parent operands";
95 
96  diag << " to ";
97  if (Region *region = succRegionNo.getRegionOrNull())
98  diag << "Region #" << region->getRegionNumber();
99  else
100  diag << "parent results";
101  return diag;
102 }
103 
104 /// Verify that types match along all region control flow edges originating from
105 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106 /// types of the inputs that flow to a successor region.
107 static LogicalResult
110  getInputsTypesForRegion) {
111  auto regionInterface = cast<RegionBranchOpInterface>(op);
112 
114  regionInterface.getSuccessorRegions(sourcePoint, successors);
115 
116  for (RegionSuccessor &succ : successors) {
117  FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
118  if (failed(sourceTypes))
119  return failure();
120 
121  TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
122  if (sourceTypes->size() != succInputsTypes.size()) {
123  InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
124  return printRegionEdgeName(diag, sourcePoint, succ)
125  << ": source has " << sourceTypes->size()
126  << " operands, but target successor needs "
127  << succInputsTypes.size();
128  }
129 
130  for (const auto &typesIdx :
131  llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
132  Type sourceType = std::get<0>(typesIdx.value());
133  Type inputType = std::get<1>(typesIdx.value());
134  if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
135  InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
136  return printRegionEdgeName(diag, sourcePoint, succ)
137  << ": source type #" << typesIdx.index() << " " << sourceType
138  << " should match input type #" << typesIdx.index() << " "
139  << inputType;
140  }
141  }
142  }
143  return success();
144 }
145 
146 /// Verify that types match along control flow edges described by the given op.
148  auto regionInterface = cast<RegionBranchOpInterface>(op);
149 
150  auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
151  return regionInterface.getEntrySuccessorOperands(point).getTypes();
152  };
153 
154  // Verify types along control flow edges originating from the parent.
156  inputTypesFromParent)))
157  return failure();
158 
159  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
160  if (lhs.size() != rhs.size())
161  return false;
162  for (auto types : llvm::zip(lhs, rhs)) {
163  if (!regionInterface.areTypesCompatible(std::get<0>(types),
164  std::get<1>(types))) {
165  return false;
166  }
167  }
168  return true;
169  };
170 
171  // Verify types along control flow edges originating from each region.
172  for (Region &region : op->getRegions()) {
173 
174  // Since there can be multiple terminators implementing the
175  // `RegionBranchTerminatorOpInterface`, all should have the same operand
176  // types when passing them to the same region.
177 
179  for (Block &block : region)
180  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
181  block.getTerminator()))
182  regionReturnOps.push_back(terminator);
183 
184  // If there is no return-like terminator, the op itself should verify
185  // type consistency.
186  if (regionReturnOps.empty())
187  continue;
188 
189  auto inputTypesForRegion =
191  std::optional<OperandRange> regionReturnOperands;
192  for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
193  auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
194 
195  if (!regionReturnOperands) {
196  regionReturnOperands = terminatorOperands;
197  continue;
198  }
199 
200  // Found more than one ReturnLike terminator. Make sure the operand
201  // types match with the first one.
202  if (!areTypesCompatible(regionReturnOperands->getTypes(),
203  terminatorOperands.getTypes())) {
204  InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
205  return printRegionEdgeName(diag, region, point)
206  << " operands mismatch between return-like terminators";
207  }
208  }
209 
210  // All successors get the same set of operand types.
211  return TypeRange(regionReturnOperands->getTypes());
212  };
213 
214  if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
215  return failure();
216  }
217 
218  return success();
219 }
220 
221 /// Return `true` if region `r` is reachable from region `begin` according to
222 /// the RegionBranchOpInterface (by taking a branch).
223 static bool isRegionReachable(Region *begin, Region *r) {
224  assert(begin->getParentOp() == r->getParentOp() &&
225  "expected that both regions belong to the same op");
226  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
227  SmallVector<bool> visited(op->getNumRegions(), false);
228  visited[begin->getRegionNumber()] = true;
229 
230  // Retrieve all successors of the region and enqueue them in the worklist.
231  SmallVector<Region *> worklist;
232  auto enqueueAllSuccessors = [&](Region *region) {
233  SmallVector<RegionSuccessor> successors;
234  op.getSuccessorRegions(region, successors);
235  for (RegionSuccessor successor : successors)
236  if (!successor.isParent())
237  worklist.push_back(successor.getSuccessor());
238  };
239  enqueueAllSuccessors(begin);
240 
241  // Process all regions in the worklist via DFS.
242  while (!worklist.empty()) {
243  Region *nextRegion = worklist.pop_back_val();
244  if (nextRegion == r)
245  return true;
246  if (visited[nextRegion->getRegionNumber()])
247  continue;
248  visited[nextRegion->getRegionNumber()] = true;
249  enqueueAllSuccessors(nextRegion);
250  }
251 
252  return false;
253 }
254 
255 /// Return `true` if `a` and `b` are in mutually exclusive regions.
256 ///
257 /// 1. Find the first common ancestor of `a` and `b` that implements
258 /// RegionBranchOpInterface.
259 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
260 /// contained.
261 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
262 /// mutually exclusive if they are not reachable from each other as per
263 /// RegionBranchOpInterface::getSuccessorRegions.
265  assert(a && "expected non-empty operation");
266  assert(b && "expected non-empty operation");
267 
268  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
269  while (branchOp) {
270  // Check if b is inside branchOp. (We already know that a is.)
271  if (!branchOp->isProperAncestor(b)) {
272  // Check next enclosing RegionBranchOpInterface.
273  branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
274  continue;
275  }
276 
277  // b is contained in branchOp. Retrieve the regions in which `a` and `b`
278  // are contained.
279  Region *regionA = nullptr, *regionB = nullptr;
280  for (Region &r : branchOp->getRegions()) {
281  if (r.findAncestorOpInRegion(*a)) {
282  assert(!regionA && "already found a region for a");
283  regionA = &r;
284  }
285  if (r.findAncestorOpInRegion(*b)) {
286  assert(!regionB && "already found a region for b");
287  regionB = &r;
288  }
289  }
290  assert(regionA && regionB && "could not find region of op");
291 
292  // `a` and `b` are in mutually exclusive regions if both regions are
293  // distinct and neither region is reachable from the other region.
294  return regionA != regionB && !isRegionReachable(regionA, regionB) &&
295  !isRegionReachable(regionB, regionA);
296  }
297 
298  // Could not find a common RegionBranchOpInterface among a's and b's
299  // ancestors.
300  return false;
301 }
302 
303 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
304  Region *region = &getOperation()->getRegion(index);
305  return isRegionReachable(region, region);
306 }
307 
309  while (Region *region = op->getParentRegion()) {
310  op = region->getParentOp();
311  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
312  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
313  return region;
314  }
315  return nullptr;
316 }
317 
319  Region *region = value.getParentRegion();
320  while (region) {
321  Operation *op = region->getParentOp();
322  if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
323  if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
324  return region;
325  region = op->getParentRegion();
326  }
327  return nullptr;
328 }
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionBranchPoint succRegionNo)
static LogicalResult verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionBranchPoint)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourcePoint.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by taking a branch).
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getSuccessor(unsigned index)
Definition: Operation.h:687
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the successor.
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is within the range of operands, or std::nullopt if operandIndex isn't a successor operand index.
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described by the given op.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBranchOpInterface, or nullptr if no such region exists.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26