21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/DebugLog.h"
29#define DEBUG_TYPE "dataflow"
61 for (
Value argument : region.front().getArguments())
65 return initializeRecursively(top);
69AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
70 LDBG() <<
"Initializing recursively for operation: " << op->
getName();
74 if (failed(visitOperation(op))) {
75 LDBG() <<
"Failed to visit operation: " << op->
getName();
80 LDBG() <<
"Processing region with " << region.getBlocks().size()
82 for (
Block &block : region) {
83 LDBG() <<
"Processing block with " << block.getNumArguments()
86 ->blockContentSubscribe(
this);
89 LDBG() <<
"Recursively initializing nested operation: " << op.
getName();
90 if (failed(initializeRecursively(&op))) {
91 LDBG() <<
"Failed to initialize nested operation: " << op.
getName();
98 LDBG() <<
"Successfully completed recursive initialization for operation: "
106 return visitOperation(point->
getPrevOp());
112AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
127 resultLattices.push_back(resultLattice);
131 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
133 {branch, branch->getResults()},
139 SmallVector<const AbstractSparseLattice *> operandLattices;
144 operandLattices.push_back(operandLattice);
147 if (
auto call = dyn_cast<CallOpInterface>(op))
154void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
164 SmallVector<AbstractSparseLattice *> argLattices;
168 argLattices.push_back(argLattice);
175 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
176 if (callable && callable.getCallableRegion() == block->
getParent())
180 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
194 Block *predecessor = *it;
198 auto *edgeExecutable =
200 edgeExecutable->blockContentSubscribe(
this);
201 if (!edgeExecutable->isLive())
207 SuccessorOperands operands =
208 branch.getSuccessorOperands(it.getSuccessorIndex());
209 for (
auto [idx, lattice] : llvm::enumerate(argLattices)) {
210 if (Value operand = operands[idx]) {
226 CallOpInterface call,
231 auto isExternalCallable = [&]() {
233 dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
234 return callable && !callable.getCallableRegion();
247 if (!predecessors->allPredecessorsKnown()) {
251 for (
Operation *predecessor : predecessors->getKnownPredecessors())
252 for (
auto &&[operand, resLattice] :
253 llvm::zip(predecessor->getOperands(), resultLattices))
260 CallableOpInterface callable,
262 Block *entryBlock = &callable.getCallableRegion()->
front();
267 if (!callsites->allPredecessorsKnown() ||
271 for (
Operation *callsite : callsites->getKnownPredecessors()) {
272 auto call = cast<CallOpInterface>(callsite);
273 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
274 join(std::get<1>(it),
280void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
284 assert(predecessors->allPredecessorsKnown() &&
285 "unexpected unresolved region successors");
287 for (
Operation *op : predecessors->getKnownPredecessors()) {
289 std::optional<OperandRange> operands;
293 operands = branch.getEntrySuccessorOperands(successor);
295 }
else if (
auto regionTerminator =
296 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
297 operands = regionTerminator.getSuccessorOperands(successor);
305 ValueRange inputs = predecessors->getSuccessorInputs(op);
306 assert(inputs.size() == operands->size() &&
307 "expected the same number of successor inputs as operands");
309 unsigned firstIndex = 0;
310 if (inputs.size() != lattices.size()) {
313 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
317 branch, branch->getResults().slice(firstIndex, inputs.size())),
318 lattices, firstIndex);
321 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
326 firstIndex, inputs.size())),
327 lattices, firstIndex);
331 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
367 return initializeRecursively(top);
371AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
372 if (failed(visitOperation(op)))
376 for (
Block &block : region) {
378 ->blockContentSubscribe(
this);
382 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
383 if (failed(initializeRecursively(&*it)))
398 return visitOperation(point->
getPrevOp());
404 resultLattices.reserve(values.size());
407 resultLattices.push_back(resultLattice);
409 return resultLattices;
413AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
416 resultLattices.reserve(values.size());
419 getLatticeElementFor(point,
result);
420 resultLattices.push_back(resultLattice);
422 return resultLattices;
430AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
431 LDBG() <<
"Visiting operation: " << op->
getName() <<
" with "
439 LDBG() <<
"Operation is in dead block, bailing out";
443 LDBG() <<
"Creating lattice elements for " << op->
getNumOperands()
445 SmallVector<AbstractSparseLattice *> operandLattices =
447 SmallVector<const AbstractSparseLattice *> resultLattices =
452 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
453 LDBG() <<
"Processing RegionBranchOpInterface operation";
454 visitRegionSuccessors(branch, operandLattices);
458 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
459 LDBG() <<
"Processing BranchOpInterface operation with "
469 for (
auto [index, block] : llvm::enumerate(op->
getSuccessors())) {
470 SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
472 if (!forwarded.empty()) {
473 MutableArrayRef<OpOperand> operands = op->
getOpOperands().slice(
475 for (OpOperand &operand : operands) {
476 unaccounted.reset(operand.getOperandNumber());
477 if (std::optional<BlockArgument> blockArg =
479 successorOperands, operand.getOperandNumber(), block)) {
488 for (
int index : unaccounted.set_bits()) {
497 if (
auto call = dyn_cast<CallOpInterface>(op)) {
498 LDBG() <<
"Processing CallOpInterface operation";
499 Operation *callableOp = call.resolveCallableInTable(&symbolTable);
500 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
508 OperandRange argOperands = call.getArgOperands();
509 MutableArrayRef<OpOperand> argOpOperands =
511 Region *region = callable.getCallableRegion();
512 if (!region || region->
empty() ||
521 for (
auto [blockArg, argOpOperand] :
525 unaccounted.reset(argOpOperand.getOperandNumber());
530 for (
int index : unaccounted.set_bits()) {
548 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
549 LDBG() <<
"Processing RegionBranchTerminatorOpInterface operation";
550 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
551 visitRegionSuccessorsFromTerminator(terminator, branch);
556 if (op->
hasTrait<OpTrait::ReturnLike>()) {
557 LDBG() <<
"Processing ReturnLike operation";
560 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
561 LDBG() <<
"Callable parent found, visiting callable operation";
566 LDBG() <<
"Using default visitOperationImpl for operation: " << op->
getName();
571 Operation *op, CallableOpInterface callable,
579 for (
auto [op,
result] : llvm::zip(operandLattices, callResultLattices))
591void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
592 RegionBranchOpInterface branch,
597 branch.getEntrySuccessorRegions(operands, successors);
604 OperandRange operands = branch.getEntrySuccessorOperands(successor);
607 for (
auto [operand, input] : llvm::zip(opoperands, inputs)) {
610 unaccounted.reset(operand.getOperandNumber());
615 for (
int index : unaccounted.set_bits()) {
620void AbstractSparseBackwardDataFlowAnalysis::
621 visitRegionSuccessorsFromTerminator(
622 RegionBranchTerminatorOpInterface terminator,
623 RegionBranchOpInterface branch) {
624 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
625 "expected a `RegionBranchTerminatorOpInterface` op");
626 assert(terminator->getParentOp() == branch.getOperation() &&
627 "expected `branch` to be the parent op of `terminator`");
629 SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
631 SmallVector<RegionSuccessor> successors;
632 terminator.getSuccessorRegions(operandAttributes, successors);
635 BitVector unaccounted(terminator->getNumOperands(),
true);
637 for (
const RegionSuccessor &successor : successors) {
639 OperandRange operands = terminator.getSuccessorOperands(successor);
641 for (
auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
644 unaccounted.reset(
const_cast<OpOperand &
>(opOperand).getOperandNumber());
649 for (
int index : unaccounted.set_bits()) {
655AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
656 ProgramPoint *point, Value value) {
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
LatticeAnchor anchor
The lattice anchor to which the state belongs.
friend class DataFlowSolver
Allow the framework to access the dependents.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
PredecessorIterator pred_iterator
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void addDependency(AnalysisState *state, ProgramPoint *point)
Create a dependency between the given analysis state and lattice anchor on this analysis.
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
StateT * getOrCreate(AnchorT anchor)
Get the analysis state associated with the lattice anchor.
ProgramPoint * getProgramPointAfter(Operation *op)
DataFlowAnalysis(DataFlowSolver &solver)
Create an analysis with a reference to the parent solver.
AnchorT * getLatticeAnchor(Args &&...args)
Get or create a custom lattice anchor.
void registerAnchorKind()
Register a custom lattice anchor class.
friend class DataFlowSolver
Allow the data-flow solver to access the internals of this class.
const StateT * getOrCreateFor(ProgramPoint *dependent, AnchorT anchor)
Get a read-only analysis state for the given point and create a dependency on dependent.
void enqueue(WorkItem item)
Push a work item onto the worklist.
ProgramPoint * getProgramPointAfter(Operation *op)
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
unsigned getNumSuccessors()
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
OpOperand & getOpOperand(unsigned idx)
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a successor of a region.
ValueRange getSuccessorInputs() const
Return the inputs to the successor that are remapped by the exit values of the current region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it changed.
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitCallOperand(OpOperand &operand)=0
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Meet the lattice element and propagate an update if it changed.
virtual LogicalResult visitCallableOperation(Operation *op, CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > operandLattices)
Visits a callable operation.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it changed.
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > argumentLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint *point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
virtual LogicalResult visitCallOperation(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)
Visits a call operation.
virtual void visitCallableOperation(CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > argLattices)
Visits a callable operation.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate an update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
This analysis state represents a set of live control-flow "predecessors" of a program point (either a...
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
Include the generated interface declarations.
Program point represents a specific location in the execution of a program.
bool isBlockStart() const
Block * getBlock() const
Get the block contains this program point.
Operation * getPrevOp() const
Get the previous operation of this program point.