14#include "llvm/Support/DebugLog.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
25 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
30 : producedOperandCount(producedOperandCount),
31 forwardedOperands(std::move(forwardedOperands)) {}
40std::optional<BlockArgument>
42 unsigned operandIndex,
Block *successor) {
43 LDBG() <<
"Getting branch successor argument for operand index "
44 << operandIndex <<
" in successor block";
48 if (forwardedOperands.empty()) {
49 LDBG() <<
"No forwarded operands, returning nullopt";
55 if (operandIndex < operandsStart ||
56 operandIndex >= (operandsStart + forwardedOperands.size())) {
57 LDBG() <<
"Operand index " << operandIndex <<
" out of range ["
58 << operandsStart <<
", "
59 << (operandsStart + forwardedOperands.size())
60 <<
"), returning nullopt";
67 LDBG() <<
"Computed argument index " << argIndex <<
" for successor block";
75 LDBG() <<
"Verifying branch successor operands for successor #" << succNo
76 <<
" in operation " << op->
getName();
79 unsigned operandCount = operands.
size();
81 LDBG() <<
"Branch has " << operandCount <<
" operands, target block has "
85 return op->
emitError() <<
"branch has " << operandCount
86 <<
" operands for successor #" << succNo
87 <<
", but target block has "
91 LDBG() <<
"Checking type compatibility for "
93 <<
" forwarded operands";
96 Type operandType = operands[i].getType();
98 LDBG() <<
"Checking type compatibility: operand type " << operandType
99 <<
" vs argument type " << argType;
101 if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
102 return op->
emitError() <<
"type mismatch for bb argument #" << i
103 <<
" of successor #" << succNo;
106 LDBG() <<
"Branch successor operand verification successful";
116 std::size_t expectedWeightsNum,
117 llvm::StringRef weightAnchorName,
118 llvm::StringRef weightRefName) {
122 if (weights.size() != expectedWeightsNum)
123 return op->
emitError() <<
"expects number of " << weightAnchorName
124 <<
" weights to match number of " << weightRefName
125 <<
": " << weights.size() <<
" vs "
126 << expectedWeightsNum;
128 if (llvm::all_of(weights, [](int32_t value) {
return value == 0; }))
129 return op->
emitError() <<
"branch weights cannot all be zero";
136 cast<WeightedBranchOpInterface>(op).getWeights();
147 cast<WeightedRegionBranchOpInterface>(op).getWeights();
160 diag <<
"Operation " << op->getName();
162 diag <<
"parent operands";
166 diag <<
"Region #" << region->getRegionNumber();
168 diag <<
"parent results";
179 getInputsTypesForRegion) {
181 branchOp.getSuccessorRegions(sourcePoint, successors);
184 FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
185 if (failed(sourceTypes))
188 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
189 if (sourceTypes->size() != succInputsTypes.size()) {
191 branchOp->emitOpError(
"region control flow edge ");
193 llvm::raw_string_ostream os(succStr);
196 <<
": source has " << sourceTypes->size()
197 <<
" operands, but target successor " << os.str() <<
" needs "
198 << succInputsTypes.size();
201 for (
const auto &typesIdx :
202 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
203 Type sourceType = std::get<0>(typesIdx.value());
204 Type inputType = std::get<1>(typesIdx.value());
206 if (!branchOp.areTypesCompatible(sourceType, inputType)) {
208 branchOp->emitOpError(
"along control flow edge ");
210 <<
": source type #" << typesIdx.index() <<
" " << sourceType
211 <<
" should match input type #" << typesIdx.index() <<
" "
222 auto regionInterface = cast<RegionBranchOpInterface>(op);
225 return regionInterface.getEntrySuccessorOperands(successor).getTypes();
237 for (
Block &block : region)
239 if (
auto terminator =
240 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
241 regionReturnOps.push_back(terminator);
245 if (regionReturnOps.empty())
250 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
252 auto inputTypesForRegion =
255 regionReturnOp.getSuccessorOperands(successor);
259 inputTypesForRegion)))
279 auto op = cast<RegionBranchOpInterface>(begin->
getParentOp());
280 LDBG() <<
"Starting region graph traversal from region #"
285 LDBG() <<
"Initialized visited array with " << op->getNumRegions()
290 auto enqueueAllSuccessors = [&](
Region *region) {
291 LDBG() <<
"Enqueuing successors for region #" << region->getRegionNumber();
293 for (
Block &block : *region) {
297 dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
301 operandAttributes.resize(terminator->getNumOperands());
302 terminator.getSuccessorRegions(operandAttributes, successors);
303 LDBG() <<
"Found " << successors.size()
304 <<
" successors from terminator in block";
306 if (!successor.isParent()) {
307 worklist.push_back(successor.getSuccessor());
308 LDBG() <<
"Added region #"
309 << successor.getSuccessor()->getRegionNumber()
312 LDBG() <<
"Skipping parent successor";
317 enqueueAllSuccessors(begin);
318 LDBG() <<
"Initial worklist size: " << worklist.size();
321 while (!worklist.empty()) {
322 Region *nextRegion = worklist.pop_back_val();
324 <<
" from worklist (remaining: " << worklist.size() <<
")";
326 if (stopConditionFn(nextRegion, visited)) {
327 LDBG() <<
"Stop condition met for region #"
332 llvm::errs() <<
"Region " << *nextRegion <<
" has no parent op\n";
337 <<
" already visited, skipping";
343 enqueueAllSuccessors(nextRegion);
346 LDBG() <<
"Traversal completed, returning false";
354 "expected that both regions belong to the same op");
358 return nextRegion == r;
372 LDBG() <<
"Checking if operations are in mutually exclusive regions: "
373 << a->
getName() <<
" and " <<
b->getName();
375 assert(a &&
"expected non-empty operation");
376 assert(
b &&
"expected non-empty operation");
380 LDBG() <<
"Checking branch operation " << branchOp->getName();
383 if (!branchOp->isProperAncestor(
b)) {
384 LDBG() <<
"Operation b is not inside branchOp, checking next ancestor";
386 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
390 LDBG() <<
"Both operations are inside branchOp, finding their regions";
394 Region *regionA =
nullptr, *regionB =
nullptr;
395 for (
Region &r : branchOp->getRegions()) {
396 if (r.findAncestorOpInRegion(*a)) {
397 assert(!regionA &&
"already found a region for a");
399 LDBG() <<
"Found region #" << r.
getRegionNumber() <<
" for operation a";
401 if (r.findAncestorOpInRegion(*
b)) {
402 assert(!regionB &&
"already found a region for b");
404 LDBG() <<
"Found region #" << r.getRegionNumber() <<
" for operation b";
407 assert(regionA && regionB &&
"could not find region of op");
409 LDBG() <<
"Region A: #" << regionA->
getRegionNumber() <<
", Region B: #"
410 << regionB->getRegionNumber();
414 bool regionsAreDistinct = (regionA != regionB);
418 LDBG() <<
"Regions distinct: " << regionsAreDistinct
419 <<
", A not reachable from B: " << aNotReachableFromB
420 <<
", B not reachable from A: " << bNotReachableFromA;
422 bool mutuallyExclusive =
423 regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
424 LDBG() <<
"Operations are mutually exclusive: " << mutuallyExclusive;
426 return mutuallyExclusive;
431 LDBG() <<
"No common RegionBranchOpInterface found, operations are not "
432 "mutually exclusive";
436bool RegionBranchOpInterface::isRepetitiveRegion(
unsigned index) {
437 LDBG() <<
"Checking if region #" <<
index <<
" is repetitive in operation "
438 << getOperation()->getName();
440 Region *region = &getOperation()->getRegion(
index);
443 LDBG() <<
"Region #" <<
index <<
" is repetitive: " << isRepetitive;
447bool RegionBranchOpInterface::hasLoop() {
448 LDBG() <<
"Checking if operation " << getOperation()->getName()
453 LDBG() <<
"Found " << entryRegions.size() <<
" entry regions";
456 if (!successor.isParent()) {
457 LDBG() <<
"Checking entry region #"
458 << successor.getSuccessor()->getRegionNumber() <<
" for loops";
465 return visited[nextRegion->getRegionNumber()];
469 LDBG() <<
"Found loop in entry region #"
470 << successor.getSuccessor()->getRegionNumber();
474 LDBG() <<
"Skipping parent successor";
478 LDBG() <<
"No loops found in operation";
483 LDBG() <<
"Finding enclosing repetitive region for operation "
491 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
493 <<
"Found RegionBranchOpInterface, checking if region is repetitive";
499 LDBG() <<
"Parent operation does not implement RegionBranchOpInterface";
503 LDBG() <<
"No enclosing repetitive region found";
508 LDBG() <<
"Finding enclosing repetitive region for value";
516 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
518 <<
"Found RegionBranchOpInterface, checking if region is repetitive";
524 LDBG() <<
"Parent operation does not implement RegionBranchOpInterface";
529 LDBG() <<
"No enclosing repetitive region found for value";
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionSuccessor succRegionNo)
static LogicalResult verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionSuccessor)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourcePoint.
function_ref< bool(Region *, ArrayRef< bool > visited)> StopConditionFn
Stop condition for traverseRegionGraph.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
This class represents a diagnostic that is inflight and set to be reported.
This class provides a mutable adaptor for a range of operands.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
unsigned getNumSuccessors()
unsigned getNumRegions()
Returns the number of regions held by this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Block * getSuccessor(unsigned index)
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Region * getParentRegion()
Return the Region in which this Value is defined.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described by the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
Include the generated interface declarations.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...
llvm::function_ref< Fn > function_ref