21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/DebugLog.h"
29#define DEBUG_TYPE "dataflow"
61 for (
Value argument : region.front().getArguments())
65 return initializeRecursively(top);
69AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
70 LDBG() <<
"Initializing recursively for operation: "
75 if (failed(visitOperation(op))) {
76 LDBG() <<
"Failed to visit operation: "
82 LDBG() <<
"Processing region with " << region.getBlocks().size()
84 for (
Block &block : region) {
85 LDBG() <<
"Processing block with " << block.getNumArguments()
88 ->blockContentSubscribe(
this);
91 LDBG() <<
"Recursively initializing nested operation: "
93 if (failed(initializeRecursively(&op))) {
94 LDBG() <<
"Failed to initialize nested operation: "
102 LDBG() <<
"Successfully completed recursive initialization for operation: "
103 << OpWithFlags(op, OpPrintingFlags().skipRegions());
110 return visitOperation(point->
getPrevOp());
116AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
131 resultLattices.push_back(resultLattice);
135 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
147 operandLattices.push_back(operandLattice);
150 if (
auto call = dyn_cast<CallOpInterface>(op))
157void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
167 SmallVector<AbstractSparseLattice *> argLattices;
171 argLattices.push_back(argLattice);
178 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
179 if (callable && callable.getCallableRegion() == block->
getParent())
183 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
197 Block *predecessor = *it;
201 auto *edgeExecutable =
203 edgeExecutable->blockContentSubscribe(
this);
204 if (!edgeExecutable->isLive())
210 SuccessorOperands operands =
211 branch.getSuccessorOperands(it.getSuccessorIndex());
212 for (
auto [idx, lattice] : llvm::enumerate(argLattices)) {
213 if (Value operand = operands[idx]) {
229 CallOpInterface call,
234 auto isExternalCallable = [&]() {
236 dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
237 return callable && !callable.getCallableRegion();
250 if (!predecessors->allPredecessorsKnown()) {
254 for (
Operation *predecessor : predecessors->getKnownPredecessors())
255 for (
auto &&[operand, resLattice] :
256 llvm::zip(predecessor->getOperands(), resultLattices))
263 CallableOpInterface callable,
265 Block *entryBlock = &callable.getCallableRegion()->
front();
270 if (!callsites->allPredecessorsKnown() ||
274 for (
Operation *callsite : callsites->getKnownPredecessors()) {
275 auto call = cast<CallOpInterface>(callsite);
276 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
277 join(std::get<1>(it),
283void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
287 assert(predecessors->allPredecessorsKnown() &&
288 "unexpected unresolved region successors");
290 for (
Operation *op : predecessors->getKnownPredecessors()) {
292 std::optional<OperandRange> operands;
296 operands = branch.getEntrySuccessorOperands(successor);
298 }
else if (
auto regionTerminator =
299 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
300 operands = regionTerminator.getSuccessorOperands(successor);
308 ValueRange inputs = predecessors->getSuccessorInputs(op);
309 assert(inputs.size() == operands->size() &&
310 "expected the same number of successor inputs as operands");
312 unsigned firstIndex = 0;
313 if (inputs.size() != lattices.size()) {
316 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
319 branch->getResults().slice(firstIndex, inputs.size()), lattices,
323 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
326 branch, RegionSuccessor(region),
327 region->
getArguments().slice(firstIndex, inputs.size()), lattices,
332 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
368 return initializeRecursively(top);
372AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
373 if (failed(visitOperation(op)))
377 for (
Block &block : region) {
379 ->blockContentSubscribe(
this);
383 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
384 if (failed(initializeRecursively(&*it)))
399 return visitOperation(point->
getPrevOp());
405 resultLattices.reserve(values.size());
408 resultLattices.push_back(resultLattice);
410 return resultLattices;
414AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
417 resultLattices.reserve(values.size());
420 getLatticeElementFor(point,
result);
421 resultLattices.push_back(resultLattice);
423 return resultLattices;
431AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
432 LDBG() <<
"Visiting operation: "
433 << OpWithFlags(op, OpPrintingFlags().skipRegions()) <<
" with "
441 LDBG() <<
"Operation is in dead block, bailing out";
445 LDBG() <<
"Creating lattice elements for " << op->
getNumOperands()
447 SmallVector<AbstractSparseLattice *> operandLattices =
449 SmallVector<const AbstractSparseLattice *> resultLattices =
454 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
455 LDBG() <<
"Processing RegionBranchOpInterface operation";
456 visitRegionSuccessors(branch, operandLattices);
460 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
461 LDBG() <<
"Processing BranchOpInterface operation with "
471 for (
auto [index, block] : llvm::enumerate(op->
getSuccessors())) {
472 SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
474 if (!forwarded.empty()) {
475 MutableArrayRef<OpOperand> operands = op->
getOpOperands().slice(
477 for (OpOperand &operand : operands) {
478 unaccounted.reset(operand.getOperandNumber());
479 if (std::optional<BlockArgument> blockArg =
481 successorOperands, operand.getOperandNumber(), block)) {
490 for (
int index : unaccounted.set_bits()) {
499 if (
auto call = dyn_cast<CallOpInterface>(op)) {
500 LDBG() <<
"Processing CallOpInterface operation";
501 Operation *callableOp = call.resolveCallableInTable(&symbolTable);
502 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
510 OperandRange argOperands = call.getArgOperands();
511 MutableArrayRef<OpOperand> argOpOperands =
513 Region *region = callable.getCallableRegion();
514 if (!region || region->
empty() ||
523 for (
auto [blockArg, argOpOperand] :
527 unaccounted.reset(argOpOperand.getOperandNumber());
532 for (
int index : unaccounted.set_bits()) {
550 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
551 LDBG() <<
"Processing RegionBranchTerminatorOpInterface operation";
552 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
553 visitRegionSuccessorsFromTerminator(terminator, branch);
558 if (op->
hasTrait<OpTrait::ReturnLike>()) {
559 LDBG() <<
"Processing ReturnLike operation";
562 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
563 LDBG() <<
"Callable parent found, visiting callable operation";
568 LDBG() <<
"Using default visitOperationImpl for operation: "
569 << OpWithFlags(op, OpPrintingFlags().skipRegions());
574 Operation *op, CallableOpInterface callable,
582 for (
auto [op,
result] : llvm::zip(operandLattices, callResultLattices))
594void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
595 RegionBranchOpInterface branch,
599 BitVector unaccounted(branch->getNumOperands(),
true);
602 for (
const auto &[operand, inputs] : mapping) {
603 for (
Value input : inputs) {
606 unaccounted.reset(operand->getOperandNumber());
613 branch.getEntrySuccessorRegions(operands, successors);
620 ValueRange inputs = branch.getSuccessorInputs(successor);
624 if (!llvm::is_contained(inputs, argument)) {
625 noControlFlowArguments.push_back(argument);
633 for (
int index : unaccounted.set_bits()) {
638void AbstractSparseBackwardDataFlowAnalysis::
639 visitRegionSuccessorsFromTerminator(
640 RegionBranchTerminatorOpInterface terminator,
641 RegionBranchOpInterface branch) {
642 assert(terminator->getParentOp() == branch.getOperation() &&
643 "expected `branch` to be the parent op of `terminator`");
647 BitVector unaccounted(terminator->getNumOperands(),
true);
650 branch.getSuccessorOperandInputMapping(mapping,
651 RegionBranchPoint(terminator));
652 for (
const auto &[operand, inputs] : mapping) {
653 for (Value input : inputs) {
656 unaccounted.reset(operand->getOperandNumber());
662 for (
int index : unaccounted.set_bits()) {
668AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
669 ProgramPoint *point, Value value) {
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
LatticeAnchor anchor
The lattice anchor to which the state belongs.
friend class DataFlowSolver
Allow the framework to access the dependents.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
PredecessorIterator pred_iterator
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void addDependency(AnalysisState *state, ProgramPoint *point)
Create a dependency between the given analysis state and lattice anchor on this analysis.
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
StateT * getOrCreate(AnchorT anchor)
Get the analysis state associated with the lattice anchor.
ProgramPoint * getProgramPointAfter(Operation *op)
DataFlowAnalysis(DataFlowSolver &solver)
Create an analysis with a reference to the parent solver.
AnchorT * getLatticeAnchor(Args &&...args)
Get or create a custom lattice anchor.
void registerAnchorKind()
Register a custom lattice anchor class.
friend class DataFlowSolver
Allow the data-flow solver to access the internals of this class.
const StateT * getOrCreateFor(ProgramPoint *dependent, AnchorT anchor)
Get a read-only analysis state for the given point and create a dependency on dependent.
void enqueue(WorkItem item)
Push a work item onto the worklist.
ProgramPoint * getProgramPointAfter(Operation *op)
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
unsigned getNumSuccessors()
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
OpOperand & getOpOperand(unsigned idx)
unsigned getNumResults()
Return the number of results held by this operation.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it chaned.
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitCallOperand(OpOperand &operand)=0
virtual void visitNonControlFlowArguments(RegionSuccessor &successor, ArrayRef< BlockArgument > arguments)=0
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
virtual LogicalResult visitCallableOperation(Operation *op, CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > operandLattices)
Visits a callable operation.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it chaned.
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > argumentLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint *point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
virtual LogicalResult visitCallOperation(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)
Visits a call operation.
virtual void visitCallableOperation(CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > argLattices)
Visits a callable operation.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ValueRange successorInputs, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
This analysis state represents a set of live control-flow "predecessors" of a program point (either a...
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
Program point represents a specific location in the execution of a program.
bool isBlockStart() const
Block * getBlock() const
Get the block contains this program point.
Operation * getPrevOp() const
Get the previous operation of this program point.