14#include "llvm/Support/DebugLog.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
25 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
30 : producedOperandCount(producedOperandCount),
31 forwardedOperands(std::move(forwardedOperands)) {}
40std::optional<BlockArgument>
42 unsigned operandIndex,
Block *successor) {
43 LDBG() <<
"Getting branch successor argument for operand index "
44 << operandIndex <<
" in successor block";
48 if (forwardedOperands.empty()) {
49 LDBG() <<
"No forwarded operands, returning nullopt";
55 if (operandIndex < operandsStart ||
56 operandIndex >= (operandsStart + forwardedOperands.size())) {
57 LDBG() <<
"Operand index " << operandIndex <<
" out of range ["
58 << operandsStart <<
", "
59 << (operandsStart + forwardedOperands.size())
60 <<
"), returning nullopt";
67 LDBG() <<
"Computed argument index " << argIndex <<
" for successor block";
75 LDBG() <<
"Verifying branch successor operands for successor #" << succNo
76 <<
" in operation " << op->
getName();
79 unsigned operandCount = operands.
size();
81 LDBG() <<
"Branch has " << operandCount <<
" operands, target block has "
85 return op->
emitError() <<
"branch has " << operandCount
86 <<
" operands for successor #" << succNo
87 <<
", but target block has "
91 LDBG() <<
"Checking type compatibility for "
93 <<
" forwarded operands";
96 Type operandType = operands[i].getType();
98 LDBG() <<
"Checking type compatibility: operand type " << operandType
99 <<
" vs argument type " << argType;
101 if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
102 return op->
emitError() <<
"type mismatch for bb argument #" << i
103 <<
" of successor #" << succNo;
106 LDBG() <<
"Branch successor operand verification successful";
116 std::size_t expectedWeightsNum,
117 llvm::StringRef weightAnchorName,
118 llvm::StringRef weightRefName) {
122 if (weights.size() != expectedWeightsNum)
123 return op->
emitError() <<
"expects number of " << weightAnchorName
124 <<
" weights to match number of " << weightRefName
125 <<
": " << weights.size() <<
" vs "
126 << expectedWeightsNum;
128 if (llvm::all_of(weights, [](int32_t value) {
return value == 0; }))
129 return op->
emitError() <<
"branch weights cannot all be zero";
136 cast<WeightedBranchOpInterface>(op).getWeights();
147 cast<WeightedRegionBranchOpInterface>(op).getWeights();
157 auto regionInterface = cast<RegionBranchOpInterface>(op);
162 regionInterface.getAllRegionBranchPoints();
165 regionInterface.getSuccessorRegions(branchPoint, successors);
169 auto emitRegionEdgeError = [&]() {
171 regionInterface->emitOpError(
"along control flow edge from ");
172 if (branchPoint.isParent()) {
174 diag.attachNote(op->
getLoc()) <<
"region branch point";
177 << branchPoint.getTerminatorPredecessorOrNull()->getName();
179 branchPoint.getTerminatorPredecessorOrNull()->getLoc())
180 <<
"region branch point";
183 if (
Region *region = successor.getSuccessor()) {
184 diag <<
"Region #" << region->getRegionNumber();
193 regionInterface.getSuccessorOperands(branchPoint, successor);
194 ValueRange succInputs = successor.getSuccessorInputs();
195 if (succOperands.size() != succInputs.size()) {
196 return emitRegionEdgeError()
197 <<
": region branch point has " << succOperands.size()
198 <<
" operands, but region successor needs " << succInputs.size()
205 for (
const auto &typesIdx :
206 llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
207 Type succOperandType = std::get<0>(typesIdx.value());
208 Type succInputType = std::get<1>(typesIdx.value());
209 if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
210 return emitRegionEdgeError()
211 <<
": successor operand type #" << typesIdx.index() <<
" "
212 << succOperandType <<
" should match successor input type #"
213 << typesIdx.index() <<
" " << succInputType;
232 auto op = cast<RegionBranchOpInterface>(begin->
getParentOp());
233 LDBG() <<
"Starting region graph traversal from region #"
238 LDBG() <<
"Initialized visited array with " << op->getNumRegions()
243 auto enqueueAllSuccessors = [&](
Region *region) {
244 LDBG() <<
"Enqueuing successors for region #" << region->getRegionNumber();
246 for (
Block &block : *region) {
250 dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
254 operandAttributes.resize(terminator->getNumOperands());
255 terminator.getSuccessorRegions(operandAttributes, successors);
256 LDBG() <<
"Found " << successors.size()
257 <<
" successors from terminator in block";
259 if (!successor.isParent()) {
260 worklist.push_back(successor.getSuccessor());
261 LDBG() <<
"Added region #"
262 << successor.getSuccessor()->getRegionNumber()
265 LDBG() <<
"Skipping parent successor";
270 enqueueAllSuccessors(begin);
271 LDBG() <<
"Initial worklist size: " << worklist.size();
274 while (!worklist.empty()) {
275 Region *nextRegion = worklist.pop_back_val();
277 <<
" from worklist (remaining: " << worklist.size() <<
")";
279 if (stopConditionFn(nextRegion, visited)) {
280 LDBG() <<
"Stop condition met for region #"
285 llvm::errs() <<
"Region " << *nextRegion <<
" has no parent op\n";
290 <<
" already visited, skipping";
296 enqueueAllSuccessors(nextRegion);
299 LDBG() <<
"Traversal completed, returning false";
307 "expected that both regions belong to the same op");
311 return nextRegion == r;
325 LDBG() <<
"Checking if operations are in mutually exclusive regions: "
326 << a->
getName() <<
" and " <<
b->getName();
328 assert(a &&
"expected non-empty operation");
329 assert(
b &&
"expected non-empty operation");
333 LDBG() <<
"Checking branch operation " << branchOp->getName();
336 if (!branchOp->isProperAncestor(
b)) {
337 LDBG() <<
"Operation b is not inside branchOp, checking next ancestor";
339 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
343 LDBG() <<
"Both operations are inside branchOp, finding their regions";
347 Region *regionA =
nullptr, *regionB =
nullptr;
348 for (
Region &r : branchOp->getRegions()) {
349 if (r.findAncestorOpInRegion(*a)) {
350 assert(!regionA &&
"already found a region for a");
352 LDBG() <<
"Found region #" << r.
getRegionNumber() <<
" for operation a";
354 if (r.findAncestorOpInRegion(*
b)) {
355 assert(!regionB &&
"already found a region for b");
357 LDBG() <<
"Found region #" << r.getRegionNumber() <<
" for operation b";
360 assert(regionA && regionB &&
"could not find region of op");
362 LDBG() <<
"Region A: #" << regionA->
getRegionNumber() <<
", Region B: #"
363 << regionB->getRegionNumber();
367 bool regionsAreDistinct = (regionA != regionB);
371 LDBG() <<
"Regions distinct: " << regionsAreDistinct
372 <<
", A not reachable from B: " << aNotReachableFromB
373 <<
", B not reachable from A: " << bNotReachableFromA;
375 bool mutuallyExclusive =
376 regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
377 LDBG() <<
"Operations are mutually exclusive: " << mutuallyExclusive;
379 return mutuallyExclusive;
384 LDBG() <<
"No common RegionBranchOpInterface found, operations are not "
385 "mutually exclusive";
389bool RegionBranchOpInterface::isRepetitiveRegion(
unsigned index) {
390 LDBG() <<
"Checking if region #" <<
index <<
" is repetitive in operation "
391 << getOperation()->getName();
393 Region *region = &getOperation()->getRegion(
index);
396 LDBG() <<
"Region #" <<
index <<
" is repetitive: " << isRepetitive;
400bool RegionBranchOpInterface::hasLoop() {
401 LDBG() <<
"Checking if operation " << getOperation()->getName()
406 LDBG() <<
"Found " << entryRegions.size() <<
" entry regions";
409 if (!successor.isParent()) {
410 LDBG() <<
"Checking entry region #"
411 << successor.getSuccessor()->getRegionNumber() <<
" for loops";
418 return visited[nextRegion->getRegionNumber()];
422 LDBG() <<
"Found loop in entry region #"
423 << successor.getSuccessor()->getRegionNumber();
427 LDBG() <<
"Skipping parent successor";
431 LDBG() <<
"No loops found in operation";
439 return getEntrySuccessorOperands(dest);
440 auto terminator = cast<RegionBranchTerminatorOpInterface>(
442 return terminator.getSuccessorOperands(dest);
454 branchOp.getSuccessorRegions(src, successors);
456 OperandRange operands = branchOp.getSuccessorOperands(src, dst);
457 assert(operands.size() == dst.getSuccessorInputs().size() &&
458 "expected the same number of operands and inputs");
459 for (
const auto &[operand, input] : llvm::zip_equal(
461 mapping[&operand].push_back(input);
464void RegionBranchOpInterface::getSuccessorOperandInputMapping(
466 std::optional<RegionBranchPoint> src) {
467 if (src.has_value()) {
478RegionBranchOpInterface::getAllRegionBranchPoints() {
481 for (
Region ®ion : getOperation()->getRegions()) {
482 for (
Block &block : region) {
485 if (
auto terminator =
486 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
494 LDBG() <<
"Finding enclosing repetitive region for operation "
498 LDBG() <<
"Checking region #" << region->getRegionNumber()
499 <<
" in operation " << region->getParentOp()->getName();
502 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
504 <<
"Found RegionBranchOpInterface, checking if region is repetitive";
505 if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
506 LDBG() <<
"Found repetitive region #" << region->getRegionNumber();
510 LDBG() <<
"Parent operation does not implement RegionBranchOpInterface";
514 LDBG() <<
"No enclosing repetitive region found";
519 LDBG() <<
"Finding enclosing repetitive region for value";
527 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
529 <<
"Found RegionBranchOpInterface, checking if region is repetitive";
535 LDBG() <<
"Parent operation does not implement RegionBranchOpInterface";
540 LDBG() <<
"No enclosing repetitive region found for value";
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static void getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp, RegionBranchSuccessorMapping &mapping, RegionBranchPoint src)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
function_ref< bool(Region *, ArrayRef< bool > visited)> StopConditionFn
Stop condition for traverseRegionGraph.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
static std::string diag(const llvm::Value &value)
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
This class represents a diagnostic that is inflight and set to be reported.
This class provides a mutable adaptor for a range of operands.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
unsigned getNumSuccessors()
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
Block * getSuccessor(unsigned index)
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Region * getParentRegion()
Return the Region in which this Value is defined.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
LogicalResult verifyRegionBranchOpInterface(Operation *op)
Verify that types match along control flow edges described by the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...
llvm::function_ref< Fn > function_ref