21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Casting.h"
49 registerAnchorKind<CFGEdge>();
59 for (
Value argument : region.front().getArguments())
63 return initializeRecursively(top);
67 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
70 if (failed(visitOperation(op)))
74 for (
Block &block : region) {
76 ->blockContentSubscribe(
this);
79 if (failed(initializeRecursively(&op)))
90 return visitOperation(point->
getPrevOp());
96 AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
111 resultLattices.push_back(resultLattice);
115 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
128 operandLattices.push_back(operandLattice);
131 if (
auto call = dyn_cast<CallOpInterface>(op)) {
135 dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
137 (callable && !callable.getCallableRegion())) {
144 const auto *predecessors = getOrCreateFor<PredecessorState>(
148 if (!predecessors->allPredecessorsKnown()) {
152 for (
Operation *predecessor : predecessors->getKnownPredecessors())
153 for (
auto &&[operand, resLattice] :
154 llvm::zip(predecessor->getOperands(), resultLattices))
164 void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
178 argLattices.push_back(argLattice);
185 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
186 if (callable && callable.getCallableRegion() == block->
getParent()) {
187 const auto *callsites = getOrCreateFor<PredecessorState>(
191 if (!callsites->allPredecessorsKnown() ||
195 for (
Operation *callsite : callsites->getKnownPredecessors()) {
196 auto call = cast<CallOpInterface>(callsite);
197 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
198 join(std::get<1>(it),
206 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
220 Block *predecessor = *it;
224 auto *edgeExecutable =
225 getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
226 edgeExecutable->blockContentSubscribe(
this);
227 if (!edgeExecutable->isLive())
234 branch.getSuccessorOperands(it.getSuccessorIndex());
236 if (
Value operand = operands[idx]) {
251 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
254 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
255 assert(predecessors->allPredecessorsKnown() &&
256 "unexpected unresolved region successors");
258 for (
Operation *op : predecessors->getKnownPredecessors()) {
260 std::optional<OperandRange> operands;
264 operands = branch.getEntrySuccessorOperands(successor);
266 }
else if (
auto regionTerminator =
267 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
268 operands = regionTerminator.getSuccessorOperands(successor);
276 ValueRange inputs = predecessors->getSuccessorInputs(op);
277 assert(inputs.size() == operands->size() &&
278 "expected the same number of successor inputs as operands");
280 unsigned firstIndex = 0;
281 if (inputs.size() != lattices.size()) {
284 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
288 branch->getResults().slice(firstIndex, inputs.size())),
289 lattices, firstIndex);
292 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
297 firstIndex, inputs.size())),
298 lattices, firstIndex);
302 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
333 registerAnchorKind<CFGEdge>();
338 return initializeRecursively(top);
342 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
343 if (failed(visitOperation(op)))
347 for (
Block &block : region) {
349 ->blockContentSubscribe(
this);
353 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
354 if (failed(initializeRecursively(&*it)))
369 return visitOperation(point->
getPrevOp());
375 resultLattices.reserve(values.size());
376 for (
Value result : values) {
378 resultLattices.push_back(resultLattice);
380 return resultLattices;
384 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
387 resultLattices.reserve(values.size());
388 for (
Value result : values) {
390 getLatticeElementFor(point, result);
391 resultLattices.push_back(resultLattice);
393 return resultLattices;
401 AbstractSparseBackwardDataFlowAnalysis::visitOperation(
Operation *op) {
414 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
415 visitRegionSuccessors(branch, operandLattices);
419 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
430 if (!forwarded.empty()) {
434 unaccounted.reset(operand.getOperandNumber());
435 if (std::optional<BlockArgument> blockArg =
437 successorOperands, operand.getOperandNumber(), block)) {
446 for (
int index : unaccounted.set_bits()) {
455 if (
auto call = dyn_cast<CallOpInterface>(op)) {
456 Operation *callableOp = call.resolveCallableInTable(&symbolTable);
457 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
468 Region *region = callable.getCallableRegion();
469 if (!region || region->
empty() ||
478 for (
auto [blockArg, argOpOperand] :
482 unaccounted.reset(argOpOperand.getOperandNumber());
487 for (
int index : unaccounted.set_bits()) {
505 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
506 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
507 visitRegionSuccessorsFromTerminator(terminator, branch);
515 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
523 for (
auto [op, result] :
524 llvm::zip(operandLattices, callResultLattices))
540 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
541 RegionBranchOpInterface branch,
546 branch.getEntrySuccessorRegions(operands, successors);
553 OperandRange operands = branch.getEntrySuccessorOperands(successor);
555 ValueRange inputs = successor.getSuccessorInputs();
556 for (
auto [operand, input] : llvm::zip(opoperands, inputs)) {
559 unaccounted.reset(operand.getOperandNumber());
564 for (
int index : unaccounted.set_bits()) {
569 void AbstractSparseBackwardDataFlowAnalysis::
570 visitRegionSuccessorsFromTerminator(
571 RegionBranchTerminatorOpInterface terminator,
572 RegionBranchOpInterface branch) {
573 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
574 "expected a `RegionBranchTerminatorOpInterface` op");
575 assert(terminator->getParentOp() == branch.getOperation() &&
576 "expected `branch` to be the parent op of `terminator`");
581 terminator.getSuccessorRegions(operandAttributes, successors);
584 BitVector unaccounted(terminator->getNumOperands(),
true);
587 ValueRange inputs = successor.getSuccessorInputs();
588 OperandRange operands = terminator.getSuccessorOperands(successor);
590 for (
auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
593 unaccounted.reset(
const_cast<OpOperand &
>(opOperand).getOperandNumber());
598 for (
int index : unaccounted.set_bits()) {
604 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
LatticeAnchor anchor
The lattice anchor to which the state belongs.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
bool isEntryBlock()
Return true if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void addDependency(AnalysisState *state, ProgramPoint *point)
Create a dependency between the given analysis state and lattice anchor on this analysis.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
ProgramPoint * getProgramPointAfter(Operation *op)
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
The general data-flow analysis solver.
void enqueue(WorkItem item)
Push a work item onto the worklist.
ProgramPoint * getProgramPointAfter(Operation *op)
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Implement a predecessor iterator for blocks.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This class models how operands are forwarded to block arguments in control flow.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s).
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitCallOperand(OpOperand &operand)=0
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Meet the lattice element `lhs` with `rhs` and propagate an update if it changed.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s).
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
LogicalResult visit(ProgramPoint *point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > argumentLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint *point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate an update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
virtual ChangeResult join(const AbstractSparseLattice &rhs)
Join the information contained in 'rhs' into this lattice.
virtual ChangeResult meet(const AbstractSparseLattice &rhs)
Meet (intersect) the information in this lattice with 'rhs'.
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
This analysis state represents a set of live control-flow "predecessors" of a program point (either a...
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
This trait indicates that a terminator operation is "return-like".
Program point represents a specific location in the execution of a program.
bool isBlockStart() const
Operation * getPrevOp() const
Get the previous operation of this program point.
Block * getBlock() const
Get the block that contains this program point.