14 #include "llvm/Support/DebugLog.h"
22 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
25 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
30 : producedOperandCount(producedOperandCount),
31 forwardedOperands(std::move(forwardedOperands)) {}
40 std::optional<BlockArgument>
42 unsigned operandIndex,
Block *successor) {
43 LDBG() <<
"Getting branch successor argument for operand index "
44 << operandIndex <<
" in successor block";
48 if (forwardedOperands.empty()) {
49 LDBG() <<
"No forwarded operands, returning nullopt";
55 if (operandIndex < operandsStart ||
56 operandIndex >= (operandsStart + forwardedOperands.size())) {
57 LDBG() <<
"Operand index " << operandIndex <<
" out of range ["
58 << operandsStart <<
", "
59 << (operandsStart + forwardedOperands.size())
60 <<
"), returning nullopt";
67 LDBG() <<
"Computed argument index " << argIndex <<
" for successor block";
75 LDBG() <<
"Verifying branch successor operands for successor #" << succNo
76 <<
" in operation " << op->
getName();
79 unsigned operandCount = operands.
size();
81 LDBG() <<
"Branch has " << operandCount <<
" operands, target block has "
85 return op->
emitError() <<
"branch has " << operandCount
86 <<
" operands for successor #" << succNo
87 <<
", but target block has "
91 LDBG() <<
"Checking type compatibility for "
93 <<
" forwarded operands";
96 Type operandType = operands[i].getType();
98 LDBG() <<
"Checking type compatibility: operand type " << operandType
99 <<
" vs argument type " << argType;
101 if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
102 return op->
emitError() <<
"type mismatch for bb argument #" << i
103 <<
" of successor #" << succNo;
106 LDBG() <<
"Branch successor operand verification successful";
116 std::size_t expectedWeightsNum,
117 llvm::StringRef weightAnchorName,
118 llvm::StringRef weightRefName) {
122 if (weights.size() != expectedWeightsNum)
123 return op->
emitError() <<
"expects number of " << weightAnchorName
124 <<
" weights to match number of " << weightRefName
125 <<
": " << weights.size() <<
" vs "
126 << expectedWeightsNum;
128 if (llvm::all_of(weights, [](int32_t value) {
return value == 0; }))
129 return op->
emitError() <<
"branch weights cannot all be zero";
136 cast<WeightedBranchOpInterface>(op).getWeights();
147 cast<WeightedRegionBranchOpInterface>(op).getWeights();
160 diag <<
"Operation " << op->getName();
162 diag <<
"parent operands";
166 diag <<
"Region #" << region->getRegionNumber();
168 diag <<
"parent results";
179 getInputsTypesForRegion) {
181 branchOp.getSuccessorRegions(sourcePoint, successors);
184 FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
188 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
189 if (sourceTypes->size() != succInputsTypes.size()) {
191 branchOp->emitOpError(
"region control flow edge ");
193 llvm::raw_string_ostream os(succStr);
196 <<
": source has " << sourceTypes->size()
197 <<
" operands, but target successor " << os.str() <<
" needs "
198 << succInputsTypes.size();
201 for (
const auto &typesIdx :
203 Type sourceType = std::get<0>(typesIdx.value());
204 Type inputType = std::get<1>(typesIdx.value());
206 if (!branchOp.areTypesCompatible(sourceType, inputType)) {
208 branchOp->emitOpError(
"along control flow edge ");
210 <<
": source type #" << typesIdx.index() <<
" " << sourceType
211 <<
" should match input type #" << typesIdx.index() <<
" "
222 auto regionInterface = cast<RegionBranchOpInterface>(op);
225 return regionInterface.getEntrySuccessorOperands(successor).getTypes();
237 for (
Block &block : region)
239 if (
auto terminator =
240 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
241 regionReturnOps.push_back(terminator);
245 if (regionReturnOps.empty())
250 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
252 auto inputTypesForRegion =
255 regionReturnOp.getSuccessorOperands(successor);
259 inputTypesForRegion)))
279 auto op = cast<RegionBranchOpInterface>(begin->
getParentOp());
280 LDBG() <<
"Starting region graph traversal from region #"
285 LDBG() <<
"Initialized visited array with " << op->getNumRegions()
290 auto enqueueAllSuccessors = [&](
Region *region) {
291 LDBG() <<
"Enqueuing successors for region #" << region->getRegionNumber();
293 for (
Block &block : *region) {
297 dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
301 operandAttributes.resize(terminator->getNumOperands());
302 terminator.getSuccessorRegions(operandAttributes, successors);
303 LDBG() <<
"Found " << successors.size()
304 <<
" successors from terminator in block";
306 if (!successor.isParent()) {
307 worklist.push_back(successor.getSuccessor());
308 LDBG() <<
"Added region #"
309 << successor.getSuccessor()->getRegionNumber()
312 LDBG() <<
"Skipping parent successor";
317 enqueueAllSuccessors(begin);
318 LDBG() <<
"Initial worklist size: " << worklist.size();
321 while (!worklist.empty()) {
322 Region *nextRegion = worklist.pop_back_val();
324 <<
" from worklist (remaining: " << worklist.size() <<
")";
326 if (stopConditionFn(nextRegion, visited)) {
327 LDBG() <<
"Stop condition met for region #"
331 llvm::dbgs() <<
"Region: " << nextRegion <<
"\n";
333 llvm::errs() <<
"Region " << *nextRegion <<
" has no parent op\n";
338 <<
" already visited, skipping";
344 enqueueAllSuccessors(nextRegion);
347 LDBG() <<
"Traversal completed, returning false";
355 "expected that both regions belong to the same op");
359 return nextRegion == r;
373 LDBG() <<
"Checking if operations are in mutually exclusive regions: "
376 assert(a &&
"expected non-empty operation");
377 assert(b &&
"expected non-empty operation");
381 LDBG() <<
"Checking branch operation " << branchOp->getName();
384 if (!branchOp->isProperAncestor(b)) {
385 LDBG() <<
"Operation b is not inside branchOp, checking next ancestor";
387 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
391 LDBG() <<
"Both operations are inside branchOp, finding their regions";
395 Region *regionA =
nullptr, *regionB =
nullptr;
396 for (
Region &r : branchOp->getRegions()) {
397 if (r.findAncestorOpInRegion(*a)) {
398 assert(!regionA &&
"already found a region for a");
400 LDBG() <<
"Found region #" << r.
getRegionNumber() <<
" for operation a";
402 if (r.findAncestorOpInRegion(*b)) {
403 assert(!regionB &&
"already found a region for b");
405 LDBG() <<
"Found region #" << r.getRegionNumber() <<
" for operation b";
408 assert(regionA && regionB &&
"could not find region of op");
410 LDBG() <<
"Region A: #" << regionA->
getRegionNumber() <<
", Region B: #"
411 << regionB->getRegionNumber();
415 bool regionsAreDistinct = (regionA != regionB);
419 LDBG() <<
"Regions distinct: " << regionsAreDistinct
420 <<
", A not reachable from B: " << aNotReachableFromB
421 <<
", B not reachable from A: " << bNotReachableFromA;
423 bool mutuallyExclusive =
424 regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
425 LDBG() <<
"Operations are mutually exclusive: " << mutuallyExclusive;
427 return mutuallyExclusive;
432 LDBG() <<
"No common RegionBranchOpInterface found, operations are not "
433 "mutually exclusive";
438 LDBG() <<
"Checking if region #" << index <<
" is repetitive in operation "
439 << getOperation()->getName();
441 Region *region = &getOperation()->getRegion(index);
444 LDBG() <<
"Region #" << index <<
" is repetitive: " << isRepetitive;
448 bool RegionBranchOpInterface::hasLoop() {
449 LDBG() <<
"Checking if operation " << getOperation()->getName()
454 LDBG() <<
"Found " << entryRegions.size() <<
" entry regions";
457 if (!successor.isParent()) {
458 LDBG() <<
"Checking entry region #"
459 << successor.getSuccessor()->getRegionNumber() <<
" for loops";
466 return visited[nextRegion->getRegionNumber()];
470 LDBG() <<
"Found loop in entry region #"
471 << successor.getSuccessor()->getRegionNumber();
475 LDBG() <<
"Skipping parent successor";
479 LDBG() <<
"No loops found in operation";
484 LDBG() <<
"Finding enclosing repetitive region for operation "
492 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
494 <<
"Found RegionBranchOpInterface, checking if region is repetitive";
500 LDBG() <<
"Parent operation does not implement RegionBranchOpInterface";
504 LDBG() <<
"No enclosing repetitive region found";
509 LDBG() <<
"Finding enclosing repetitive region for value";
517 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
519 <<
"Found RegionBranchOpInterface, checking if region is repetitive";
525 LDBG() <<
"Parent operation does not implement RegionBranchOpInterface";
530 LDBG() <<
"No enclosing repetitive region found for value";
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static LogicalResult verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionSuccessor)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourcePoint.
static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionSuccessor succRegionNo)
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
This class represents a diagnostic that is inflight and set to be reported.
This class provides a mutable adaptor for a range of operands.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
Block * getSuccessor(unsigned index)
unsigned getNumSuccessors()
unsigned getNumRegions()
Returns the number of regions held by this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
Region * getParentRegion()
Returns the region to which the instruction belongs.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Region * getParentRegion()
Return the Region in which this Value is defined.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegiobBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
Include the generated interface declarations.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...