21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/DebugLog.h"
29 #define DEBUG_TYPE "dataflow"
51 registerAnchorKind<CFGEdge>();
61 for (
Value argument : region.front().getArguments())
65 return initializeRecursively(top);
69 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
70 LDBG() <<
"Initializing recursively for operation: " << op->
getName();
74 if (
failed(visitOperation(op))) {
75 LDBG() <<
"Failed to visit operation: " << op->
getName();
80 LDBG() <<
"Processing region with " << region.getBlocks().size()
82 for (
Block &block : region) {
83 LDBG() <<
"Processing block with " << block.getNumArguments()
86 ->blockContentSubscribe(
this);
89 LDBG() <<
"Recursively initializing nested operation: " << op.
getName();
90 if (
failed(initializeRecursively(&op))) {
91 LDBG() <<
"Failed to initialize nested operation: " << op.
getName();
98 LDBG() <<
"Successfully completed recursive initialization for operation: "
106 return visitOperation(point->
getPrevOp());
112 AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
127 resultLattices.push_back(resultLattice);
131 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
144 operandLattices.push_back(operandLattice);
147 if (
auto call = dyn_cast<CallOpInterface>(op))
154 void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
168 argLattices.push_back(argLattice);
175 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
176 if (callable && callable.getCallableRegion() == block->
getParent())
180 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
194 Block *predecessor = *it;
198 auto *edgeExecutable =
199 getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
200 edgeExecutable->blockContentSubscribe(
this);
201 if (!edgeExecutable->isLive())
208 branch.getSuccessorOperands(it.getSuccessorIndex());
210 if (
Value operand = operands[idx]) {
226 CallOpInterface call,
231 auto isExternalCallable = [&]() {
233 dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
234 return callable && !callable.getCallableRegion();
243 const auto *predecessors = getOrCreateFor<PredecessorState>(
247 if (!predecessors->allPredecessorsKnown()) {
251 for (
Operation *predecessor : predecessors->getKnownPredecessors())
252 for (
auto &&[operand, resLattice] :
253 llvm::zip(predecessor->getOperands(), resultLattices))
260 CallableOpInterface callable,
262 Block *entryBlock = &callable.getCallableRegion()->
front();
263 const auto *callsites = getOrCreateFor<PredecessorState>(
267 if (!callsites->allPredecessorsKnown() ||
271 for (
Operation *callsite : callsites->getKnownPredecessors()) {
272 auto call = cast<CallOpInterface>(callsite);
273 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
274 join(std::get<1>(it),
280 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
283 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
284 assert(predecessors->allPredecessorsKnown() &&
285 "unexpected unresolved region successors");
287 for (
Operation *op : predecessors->getKnownPredecessors()) {
289 std::optional<OperandRange> operands;
293 operands = branch.getEntrySuccessorOperands(successor);
295 }
else if (
auto regionTerminator =
296 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
297 operands = regionTerminator.getSuccessorOperands(successor);
305 ValueRange inputs = predecessors->getSuccessorInputs(op);
306 assert(inputs.size() == operands->size() &&
307 "expected the same number of successor inputs as operands");
309 unsigned firstIndex = 0;
310 if (inputs.size() != lattices.size()) {
313 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
317 branch->getResults().slice(firstIndex, inputs.size())),
318 lattices, firstIndex);
321 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
326 firstIndex, inputs.size())),
327 lattices, firstIndex);
331 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
362 registerAnchorKind<CFGEdge>();
367 return initializeRecursively(top);
371 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
372 if (
failed(visitOperation(op)))
376 for (
Block &block : region) {
378 ->blockContentSubscribe(
this);
382 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
383 if (
failed(initializeRecursively(&*it)))
398 return visitOperation(point->
getPrevOp());
404 resultLattices.reserve(values.size());
405 for (
Value result : values) {
407 resultLattices.push_back(resultLattice);
409 return resultLattices;
413 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
416 resultLattices.reserve(values.size());
417 for (
Value result : values) {
419 getLatticeElementFor(point, result);
420 resultLattices.push_back(resultLattice);
422 return resultLattices;
430 AbstractSparseBackwardDataFlowAnalysis::visitOperation(
Operation *op) {
431 LDBG() <<
"Visiting operation: " << op->
getName() <<
" with "
439 LDBG() <<
"Operation is in dead block, bailing out";
443 LDBG() <<
"Creating lattice elements for " << op->
getNumOperands()
452 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
453 LDBG() <<
"Processing RegionBranchOpInterface operation";
454 visitRegionSuccessors(branch, operandLattices);
458 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
459 LDBG() <<
"Processing BranchOpInterface operation with "
472 if (!forwarded.empty()) {
476 unaccounted.reset(operand.getOperandNumber());
477 if (std::optional<BlockArgument> blockArg =
479 successorOperands, operand.getOperandNumber(), block)) {
488 for (
int index : unaccounted.set_bits()) {
497 if (
auto call = dyn_cast<CallOpInterface>(op)) {
498 LDBG() <<
"Processing CallOpInterface operation";
499 Operation *callableOp = call.resolveCallableInTable(&symbolTable);
500 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
511 Region *region = callable.getCallableRegion();
512 if (!region || region->
empty() ||
521 for (
auto [blockArg, argOpOperand] :
525 unaccounted.reset(argOpOperand.getOperandNumber());
530 for (
int index : unaccounted.set_bits()) {
548 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
549 LDBG() <<
"Processing RegionBranchTerminatorOpInterface operation";
550 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
551 visitRegionSuccessorsFromTerminator(terminator, branch);
557 LDBG() <<
"Processing ReturnLike operation";
560 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
561 LDBG() <<
"Callable parent found, visiting callable operation";
566 LDBG() <<
"Using default visitOperationImpl for operation: " << op->
getName();
571 Operation *op, CallableOpInterface callable,
579 for (
auto [op, result] : llvm::zip(operandLattices, callResultLattices))
591 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
592 RegionBranchOpInterface branch,
597 branch.getEntrySuccessorRegions(operands, successors);
604 OperandRange operands = branch.getEntrySuccessorOperands(successor);
606 ValueRange inputs = successor.getSuccessorInputs();
607 for (
auto [operand, input] : llvm::zip(opoperands, inputs)) {
610 unaccounted.reset(operand.getOperandNumber());
615 for (
int index : unaccounted.set_bits()) {
620 void AbstractSparseBackwardDataFlowAnalysis::
621 visitRegionSuccessorsFromTerminator(
622 RegionBranchTerminatorOpInterface terminator,
623 RegionBranchOpInterface branch) {
624 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
625 "expected a `RegionBranchTerminatorOpInterface` op");
626 assert(terminator->getParentOp() == branch.getOperation() &&
627 "expected `branch` to be the parent op of `terminator`");
632 terminator.getSuccessorRegions(operandAttributes, successors);
635 BitVector unaccounted(terminator->getNumOperands(),
true);
638 ValueRange inputs = successor.getSuccessorInputs();
639 OperandRange operands = terminator.getSuccessorOperands(successor);
641 for (
auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
644 unaccounted.reset(
const_cast<OpOperand &
>(opOperand).getOperandNumber());
649 for (
int index : unaccounted.set_bits()) {
655 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
LatticeAnchor anchor
The lattice anchor to which the state belongs.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void addDependency(AnalysisState *state, ProgramPoint *point)
Create a dependency between the given analysis state and lattice anchor on this analysis.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
ProgramPoint * getProgramPointAfter(Operation *op)
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
The general data-flow analysis solver.
void enqueue(WorkItem item)
Push a work item onto the worklist.
ProgramPoint * getProgramPointAfter(Operation *op)
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumSuccessors()
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Implement a predecessor iterator for blocks.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This class models how operands are forwarded to block arguments in control flow.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it changed.
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitCallOperand(OpOperand &operand)=0
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Meet the lattice element and propagate an update if it changed.
virtual LogicalResult visitCallableOperation(Operation *op, CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > operandLattices)
Visits a callable operation.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it changed.
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > argumentLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint *point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
virtual LogicalResult visitCallOperation(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)
Visits a call operation.
virtual void visitCallableOperation(CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > argLattices)
Visits a callable operation.
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate an update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
virtual ChangeResult join(const AbstractSparseLattice &rhs)
Join the information contained in 'rhs' into this lattice.
virtual ChangeResult meet(const AbstractSparseLattice &rhs)
Meet (intersect) the information in this lattice with 'rhs'.
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
This analysis state represents a set of live control-flow "predecessors" of a program point (either a...
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
This trait indicates that a terminator operation is "return-like".
Program point represents a specific location in the execution of a program.
bool isBlockStart() const
Operation * getPrevOp() const
Get the previous operation of this program point.
Block * getBlock() const
Get the block contains this program point.