21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Casting.h"
39 solver->
enqueue({user, analysis});
49 registerAnchorKind<CFGEdge>();
59 for (
Value argument : region.front().getArguments())
63 return initializeRecursively(top);
67 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
70 if (failed(visitOperation(op)))
74 for (
Block &block : region) {
75 getOrCreate<Executable>(&block)->blockContentSubscribe(
this);
78 if (failed(initializeRecursively(&op)))
87 if (
Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
88 return visitOperation(op);
89 visitBlock(point.get<
Block *>());
94 AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
100 if (!getOrCreate<Executable>(op->
getBlock())->isLive())
108 resultLattices.push_back(resultLattice);
112 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
113 visitRegionSuccessors({branch}, branch,
125 operandLattices.push_back(operandLattice);
128 if (
auto call = dyn_cast<CallOpInterface>(op)) {
132 dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
134 (callable && !callable.getCallableRegion())) {
141 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
144 if (!predecessors->allPredecessorsKnown()) {
148 for (
Operation *predecessor : predecessors->getKnownPredecessors())
149 for (
auto &&[operand, resLattice] :
150 llvm::zip(predecessor->getOperands(), resultLattices))
159 void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
165 if (!getOrCreate<Executable>(block)->isLive())
173 argLattices.push_back(argLattice);
180 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
181 if (callable && callable.getCallableRegion() == block->
getParent()) {
182 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
185 if (!callsites->allPredecessorsKnown() ||
189 for (
Operation *callsite : callsites->getKnownPredecessors()) {
190 auto call = cast<CallOpInterface>(callsite);
191 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
198 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
199 return visitRegionSuccessors(block, branch, block->
getParent(),
212 Block *predecessor = *it;
216 auto *edgeExecutable =
217 getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
218 edgeExecutable->blockContentSubscribe(
this);
219 if (!edgeExecutable->isLive())
226 branch.getSuccessorOperands(it.getSuccessorIndex());
228 if (
Value operand = operands[idx]) {
242 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
245 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
246 assert(predecessors->allPredecessorsKnown() &&
247 "unexpected unresolved region successors");
249 for (
Operation *op : predecessors->getKnownPredecessors()) {
251 std::optional<OperandRange> operands;
255 operands = branch.getEntrySuccessorOperands(successor);
257 }
else if (
auto regionTerminator =
258 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
259 operands = regionTerminator.getSuccessorOperands(successor);
267 ValueRange inputs = predecessors->getSuccessorInputs(op);
268 assert(inputs.size() == operands->size() &&
269 "expected the same number of successor inputs as operands");
271 unsigned firstIndex = 0;
272 if (inputs.size() != lattices.size()) {
273 if (llvm::dyn_cast_if_present<Operation *>(point)) {
275 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
279 branch->getResults().slice(firstIndex, inputs.size())),
280 lattices, firstIndex);
283 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
284 Region *region = point.get<
Block *>()->getParent();
288 firstIndex, inputs.size())),
289 lattices, firstIndex);
293 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
324 registerAnchorKind<CFGEdge>();
329 return initializeRecursively(top);
333 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
334 if (failed(visitOperation(op)))
338 for (
Block &block : region) {
339 getOrCreate<Executable>(&block)->blockContentSubscribe(
this);
343 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
344 if (failed(initializeRecursively(&*it)))
353 if (
Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
354 return visitOperation(op);
365 resultLattices.reserve(values.size());
366 for (
Value result : values) {
368 resultLattices.push_back(resultLattice);
370 return resultLattices;
374 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
377 resultLattices.reserve(values.size());
378 for (
Value result : values) {
380 getLatticeElementFor(point, result);
381 resultLattices.push_back(resultLattice);
383 return resultLattices;
391 AbstractSparseBackwardDataFlowAnalysis::visitOperation(
Operation *op) {
393 if (!getOrCreate<Executable>(op->
getBlock())->isLive())
403 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
404 visitRegionSuccessors(branch, operandLattices);
408 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
419 if (!forwarded.empty()) {
423 unaccounted.reset(operand.getOperandNumber());
424 if (std::optional<BlockArgument> blockArg =
426 successorOperands, operand.getOperandNumber(), block)) {
428 *getLatticeElementFor(op, *blockArg));
435 for (
int index : unaccounted.set_bits()) {
444 if (
auto call = dyn_cast<CallOpInterface>(op)) {
445 Operation *callableOp = call.resolveCallableInTable(&symbolTable);
446 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
457 Region *region = callable.getCallableRegion();
458 if (!region || region->
empty() ||
467 for (
auto [blockArg, argOpOperand] :
470 *getLatticeElementFor(op, blockArg));
471 unaccounted.reset(argOpOperand.getOperandNumber());
476 for (
int index : unaccounted.set_bits()) {
494 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
495 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
496 visitRegionSuccessorsFromTerminator(terminator, branch);
504 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
506 getOrCreateFor<PredecessorState>(op, callable);
510 getLatticeElementsFor(op, call->getResults());
511 for (
auto [op, result] :
512 llvm::zip(operandLattices, callResultLattices))
528 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
529 RegionBranchOpInterface branch,
534 branch.getEntrySuccessorRegions(operands, successors);
541 OperandRange operands = branch.getEntrySuccessorOperands(successor);
543 ValueRange inputs = successor.getSuccessorInputs();
544 for (
auto [operand, input] : llvm::zip(opoperands, inputs)) {
546 unaccounted.reset(operand.getOperandNumber());
551 for (
int index : unaccounted.set_bits()) {
556 void AbstractSparseBackwardDataFlowAnalysis::
557 visitRegionSuccessorsFromTerminator(
558 RegionBranchTerminatorOpInterface terminator,
559 RegionBranchOpInterface branch) {
560 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
561 "expected a `RegionBranchTerminatorOpInterface` op");
562 assert(terminator->getParentOp() == branch.getOperation() &&
563 "expected `branch` to be the parent op of `terminator`");
568 terminator.getSuccessorRegions(operandAttributes, successors);
571 BitVector unaccounted(terminator->getNumOperands(),
true);
574 ValueRange inputs = successor.getSuccessorInputs();
575 OperandRange operands = terminator.getSuccessorOperands(successor);
577 for (
auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
579 *getLatticeElementFor(terminator, input));
580 unaccounted.reset(
const_cast<OpOperand &
>(opOperand).getOperandNumber());
585 for (
int index : unaccounted.set_bits()) {
591 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
ProgramPoint point,
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
LatticeAnchor anchor
The lattice anchor to which the state belongs.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
void addDependency(AnalysisState *state, ProgramPoint point)
Create a dependency between the given analysis state and lattice anchor on this analysis.
The general data-flow analysis solver.
void enqueue(WorkItem item)
Push a work item onto the worklist.
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Implement a predecessor iterator for blocks.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This class models how operands are forwarded to block arguments in control flow.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s).
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitCallOperand(OpOperand &operand)=0
LogicalResult visit(ProgramPoint point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s).
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
LogicalResult visit(ProgramPoint point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > argumentLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
virtual LogicalResult visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
virtual ChangeResult join(const AbstractSparseLattice &rhs)
Join the information contained in 'rhs' into this lattice.
virtual ChangeResult meet(const AbstractSparseLattice &rhs)
Meet (intersect) the information in this lattice with 'rhs'.
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
This analysis state represents a set of live control-flow "predecessors" of a program point (either a...
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
This trait indicates that a terminator operation is "return-like".
Program point represents a specific location in the execution of a program.