21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/DebugLog.h"
29 #define DEBUG_TYPE "dataflow"
51 registerAnchorKind<CFGEdge>();
61 for (
Value argument : region.front().getArguments())
65 return initializeRecursively(top);
69 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
70 LDBG() <<
"Initializing recursively for operation: " << op->
getName();
74 if (
failed(visitOperation(op))) {
75 LDBG() <<
"Failed to visit operation: " << op->
getName();
80 LDBG() <<
"Processing region with " << region.getBlocks().size()
82 for (
Block &block : region) {
83 LDBG() <<
"Processing block with " << block.getNumArguments()
86 ->blockContentSubscribe(
this);
89 LDBG() <<
"Recursively initializing nested operation: " << op.
getName();
90 if (
failed(initializeRecursively(&op))) {
91 LDBG() <<
"Failed to initialize nested operation: " << op.
getName();
98 LDBG() <<
"Successfully completed recursive initialization for operation: "
106 return visitOperation(point->
getPrevOp());
112 AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
127 resultLattices.push_back(resultLattice);
131 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
144 operandLattices.push_back(operandLattice);
147 if (
auto call = dyn_cast<CallOpInterface>(op))
154 void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
168 argLattices.push_back(argLattice);
175 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
176 if (callable && callable.getCallableRegion() == block->
getParent())
180 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
194 Block *predecessor = *it;
198 auto *edgeExecutable =
199 getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
200 edgeExecutable->blockContentSubscribe(
this);
201 if (!edgeExecutable->isLive())
208 branch.getSuccessorOperands(it.getSuccessorIndex());
210 if (
Value operand = operands[idx]) {
226 CallOpInterface call,
232 dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
234 (callable && !callable.getCallableRegion())) {
241 const auto *predecessors = getOrCreateFor<PredecessorState>(
245 if (!predecessors->allPredecessorsKnown()) {
249 for (
Operation *predecessor : predecessors->getKnownPredecessors())
250 for (
auto &&[operand, resLattice] :
251 llvm::zip(predecessor->getOperands(), resultLattices))
258 CallableOpInterface callable,
260 Block *entryBlock = &callable.getCallableRegion()->
front();
261 const auto *callsites = getOrCreateFor<PredecessorState>(
265 if (!callsites->allPredecessorsKnown() ||
269 for (
Operation *callsite : callsites->getKnownPredecessors()) {
270 auto call = cast<CallOpInterface>(callsite);
271 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
272 join(std::get<1>(it),
278 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
281 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
282 assert(predecessors->allPredecessorsKnown() &&
283 "unexpected unresolved region successors");
285 for (
Operation *op : predecessors->getKnownPredecessors()) {
287 std::optional<OperandRange> operands;
291 operands = branch.getEntrySuccessorOperands(successor);
293 }
else if (
auto regionTerminator =
294 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
295 operands = regionTerminator.getSuccessorOperands(successor);
303 ValueRange inputs = predecessors->getSuccessorInputs(op);
304 assert(inputs.size() == operands->size() &&
305 "expected the same number of successor inputs as operands");
307 unsigned firstIndex = 0;
308 if (inputs.size() != lattices.size()) {
311 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
315 branch->getResults().slice(firstIndex, inputs.size())),
316 lattices, firstIndex);
319 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
324 firstIndex, inputs.size())),
325 lattices, firstIndex);
329 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
360 registerAnchorKind<CFGEdge>();
365 return initializeRecursively(top);
369 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
370 if (
failed(visitOperation(op)))
374 for (
Block &block : region) {
376 ->blockContentSubscribe(
this);
380 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
381 if (
failed(initializeRecursively(&*it)))
396 return visitOperation(point->
getPrevOp());
402 resultLattices.reserve(values.size());
403 for (
Value result : values) {
405 resultLattices.push_back(resultLattice);
407 return resultLattices;
411 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
414 resultLattices.reserve(values.size());
415 for (
Value result : values) {
417 getLatticeElementFor(point, result);
418 resultLattices.push_back(resultLattice);
420 return resultLattices;
428 AbstractSparseBackwardDataFlowAnalysis::visitOperation(
Operation *op) {
429 LDBG() <<
"Visiting operation: " << op->
getName() <<
" with "
437 LDBG() <<
"Operation is in dead block, bailing out";
441 LDBG() <<
"Creating lattice elements for " << op->
getNumOperands()
450 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
451 LDBG() <<
"Processing RegionBranchOpInterface operation";
452 visitRegionSuccessors(branch, operandLattices);
456 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
457 LDBG() <<
"Processing BranchOpInterface operation with "
470 if (!forwarded.empty()) {
474 unaccounted.reset(operand.getOperandNumber());
475 if (std::optional<BlockArgument> blockArg =
477 successorOperands, operand.getOperandNumber(), block)) {
486 for (
int index : unaccounted.set_bits()) {
495 if (
auto call = dyn_cast<CallOpInterface>(op)) {
496 LDBG() <<
"Processing CallOpInterface operation";
497 Operation *callableOp = call.resolveCallableInTable(&symbolTable);
498 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
509 Region *region = callable.getCallableRegion();
510 if (!region || region->
empty() ||
519 for (
auto [blockArg, argOpOperand] :
523 unaccounted.reset(argOpOperand.getOperandNumber());
528 for (
int index : unaccounted.set_bits()) {
546 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
547 LDBG() <<
"Processing RegionBranchTerminatorOpInterface operation";
548 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
549 visitRegionSuccessorsFromTerminator(terminator, branch);
555 LDBG() <<
"Processing ReturnLike operation";
558 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
559 LDBG() <<
"Callable parent found, visiting callable operation";
564 LDBG() <<
"Using default visitOperationImpl for operation: " << op->
getName();
569 Operation *op, CallableOpInterface callable,
577 for (
auto [op, result] : llvm::zip(operandLattices, callResultLattices))
589 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
590 RegionBranchOpInterface branch,
595 branch.getEntrySuccessorRegions(operands, successors);
602 OperandRange operands = branch.getEntrySuccessorOperands(successor);
604 ValueRange inputs = successor.getSuccessorInputs();
605 for (
auto [operand, input] : llvm::zip(opoperands, inputs)) {
608 unaccounted.reset(operand.getOperandNumber());
613 for (
int index : unaccounted.set_bits()) {
618 void AbstractSparseBackwardDataFlowAnalysis::
619 visitRegionSuccessorsFromTerminator(
620 RegionBranchTerminatorOpInterface terminator,
621 RegionBranchOpInterface branch) {
622 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
623 "expected a `RegionBranchTerminatorOpInterface` op");
624 assert(terminator->getParentOp() == branch.getOperation() &&
625 "expected `branch` to be the parent op of `terminator`");
630 terminator.getSuccessorRegions(operandAttributes, successors);
633 BitVector unaccounted(terminator->getNumOperands(),
true);
636 ValueRange inputs = successor.getSuccessorInputs();
637 OperandRange operands = terminator.getSuccessorOperands(successor);
639 for (
auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
642 unaccounted.reset(
const_cast<OpOperand &
>(opOperand).getOperandNumber());
647 for (
int index : unaccounted.set_bits()) {
653 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
LatticeAnchor anchor
The lattice anchor to which the state belongs.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void addDependency(AnalysisState *state, ProgramPoint *point)
Create a dependency between the given analysis state and lattice anchor on this analysis.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
ProgramPoint * getProgramPointAfter(Operation *op)
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
The general data-flow analysis solver.
void enqueue(WorkItem item)
Push a work item onto the worklist.
ProgramPoint * getProgramPointAfter(Operation *op)
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumSuccessors()
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Implement a predecessor iterator for blocks.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This class models how operands are forwarded to block arguments in control flow.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it changed.
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitCallOperand(OpOperand &operand)=0
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Meet the lattice element with `rhs` and propagate an update if it changed.
virtual LogicalResult visitCallableOperation(Operation *op, CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > operandLattices)
Visits a callable operation.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s) and propagate the update if it changed.
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > argumentLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint *point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
virtual LogicalResult visitCallOperation(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)
Visits a call operation.
virtual void visitCallableOperation(CallableOpInterface callable, ArrayRef< AbstractSparseLattice * > argLattices)
Visits a callable operation.
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate an update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
virtual ChangeResult join(const AbstractSparseLattice &rhs)
Join the information contained in 'rhs' into this lattice.
virtual ChangeResult meet(const AbstractSparseLattice &rhs)
Meet (intersect) the information in this lattice with 'rhs'.
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
This analysis state represents a set of live control-flow "predecessors" of a program point (either a...
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
This trait indicates that a terminator operation is "return-like".
Program point represents a specific location in the execution of a program.
bool isBlockStart() const
Operation * getPrevOp() const
Get the previous operation of this program point.
Block * getBlock() const
Get the block contains this program point.