14 #include "llvm/ADT/SmallPtrSet.h"
22 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
25 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
30 : producedOperandCount(producedOperandCount),
31 forwardedOperands(std::move(forwardedOperands)) {}
40 std::optional<BlockArgument>
42 unsigned operandIndex,
Block *successor) {
45 if (forwardedOperands.empty())
50 if (operandIndex < operandsStart ||
51 operandIndex >= (operandsStart + forwardedOperands.size()))
65 unsigned operandCount = operands.
size();
68 return op->
emitError() <<
"branch has " << operandCount
69 <<
" operands for successor #" << succNo
70 <<
", but target block has "
76 if (!cast<BranchOpInterface>(op).areTypesCompatible(
78 return op->
emitError() <<
"type mismatch for bb argument #" << i
79 <<
" of successor #" << succNo;
90 std::size_t expectedWeightsNum,
91 llvm::StringRef weightAnchorName,
92 llvm::StringRef weightRefName) {
96 if (weights.size() != expectedWeightsNum)
97 return op->
emitError() <<
"expects number of " << weightAnchorName
98 <<
" weights to match number of " << weightRefName
99 <<
": " << weights.size() <<
" vs "
100 << expectedWeightsNum;
104 return op->
emitError() <<
"weight #" << index <<
" must be non-negative";
106 if (llvm::all_of(weights, [](int32_t value) {
return value == 0; }))
107 return op->
emitError() <<
"branch weights cannot all be zero";
114 cast<WeightedBranchOpInterface>(op).getWeights();
125 cast<WeightedRegionBranchOpInterface>(op).getWeights();
138 diag <<
"Region #" << region->getRegionNumber();
140 diag <<
"parent operands";
144 diag <<
"Region #" << region->getRegionNumber();
146 diag <<
"parent results";
156 getInputsTypesForRegion) {
157 auto regionInterface = cast<RegionBranchOpInterface>(op);
160 regionInterface.getSuccessorRegions(sourcePoint, successors);
163 FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
164 if (failed(sourceTypes))
167 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
168 if (sourceTypes->size() != succInputsTypes.size()) {
171 <<
": source has " << sourceTypes->size()
172 <<
" operands, but target successor needs "
173 << succInputsTypes.size();
176 for (
const auto &typesIdx :
178 Type sourceType = std::get<0>(typesIdx.value());
179 Type inputType = std::get<1>(typesIdx.value());
180 if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
183 <<
": source type #" << typesIdx.index() <<
" " << sourceType
184 <<
" should match input type #" << typesIdx.index() <<
" "
194 auto regionInterface = cast<RegionBranchOpInterface>(op);
197 return regionInterface.getEntrySuccessorOperands(point).getTypes();
202 inputTypesFromParent)))
206 if (lhs.size() != rhs.size())
208 for (
auto types : llvm::zip(lhs, rhs)) {
209 if (!regionInterface.areTypesCompatible(std::get<0>(types),
210 std::get<1>(types))) {
225 for (
Block &block : region)
227 if (
auto terminator =
228 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
229 regionReturnOps.push_back(terminator);
233 if (regionReturnOps.empty())
236 auto inputTypesForRegion =
238 std::optional<OperandRange> regionReturnOperands;
239 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
240 auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
242 if (!regionReturnOperands) {
243 regionReturnOperands = terminatorOperands;
249 if (!areTypesCompatible(regionReturnOperands->getTypes(),
250 terminatorOperands.getTypes())) {
253 <<
" operands mismatch between return-like terminators";
258 return TypeRange(regionReturnOperands->getTypes());
280 auto op = cast<RegionBranchOpInterface>(begin->
getParentOp());
286 auto enqueueAllSuccessors = [&](
Region *region) {
288 op.getSuccessorRegions(region, successors);
290 if (!successor.isParent())
291 worklist.push_back(successor.getSuccessor());
293 enqueueAllSuccessors(begin);
296 while (!worklist.empty()) {
297 Region *nextRegion = worklist.pop_back_val();
298 if (stopConditionFn(nextRegion, visited))
303 enqueueAllSuccessors(nextRegion);
313 "expected that both regions belong to the same op");
317 return nextRegion == r;
331 assert(a &&
"expected non-empty operation");
332 assert(b &&
"expected non-empty operation");
337 if (!branchOp->isProperAncestor(b)) {
339 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
345 Region *regionA =
nullptr, *regionB =
nullptr;
346 for (
Region &r : branchOp->getRegions()) {
347 if (r.findAncestorOpInRegion(*a)) {
348 assert(!regionA &&
"already found a region for a");
351 if (r.findAncestorOpInRegion(*b)) {
352 assert(!regionB &&
"already found a region for b");
356 assert(regionA && regionB &&
"could not find region of op");
370 Region *region = &getOperation()->getRegion(index);
374 bool RegionBranchOpInterface::hasLoop() {
378 if (!successor.isParent() &&
383 return visited[nextRegion->getRegionNumber()];
392 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
403 if (
auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef< int32_t > weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName)
static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn)
Traverse the region graph starting at begin.
static InFlightDiagnostic & printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, RegionBranchPoint succRegionNo)
static LogicalResult verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, function_ref< FailureOr< TypeRange >(RegionBranchPoint)> getInputsTypesForRegion)
Verify that types match along all region control flow edges originating from sourcePoint.
static bool isRegionReachable(Region *begin, Region *r)
Return true if region r is reachable from region begin according to the RegionBranchOpInterface (by t...
static std::string diag(const llvm::Value &value)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
This class represents a diagnostic that is inflight and set to be reported.
This class provides a mutable adaptor for a range of operands.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Block * getSuccessor(unsigned index)
unsigned getNumSuccessors()
unsigned getNumRegions()
Returns the number of regions held by this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Region * getParentRegion()
Returns the region to which the instruction belongs.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class models how operands are forwarded to block arguments in control flow.
SuccessorOperands(MutableOperandRange forwardedOperands)
Constructs a SuccessorOperands with no produced operands that simply forwards operands to the success...
unsigned getProducedOperandCount() const
Returns the amount of operands that are produced internally by the operation.
unsigned size() const
Returns the amount of operands passed to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Region * getParentRegion()
Return the Region in which this Value is defined.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
LogicalResult verifyRegionBranchWeights(Operation *op)
Verify that the region weights attached to an operation implementing WeightedRegionBranchOpInterface ...
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands)
Verify that the given operands match those of the given successor block.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op)
Verify that types match along control flow edges described by the given op.
LogicalResult verifyBranchWeights(Operation *op)
Verify that the branch weights attached to an operation implementing WeightedBranchOpInterface are co...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...