22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Casting.h"
40 solver->
enqueue({user, analysis});
50 registerPointKind<CFGEdge>();
60 for (
Value argument : region.front().getArguments())
64 return initializeRecursively(top);
68 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
73 for (
Block &block : region) {
74 getOrCreate<Executable>(&block)->blockContentSubscribe(
this);
77 if (
failed(initializeRecursively(&op)))
86 if (
Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
88 else if (
Block *block = llvm::dyn_cast_if_present<Block *>(point))
95 void AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
101 if (!getOrCreate<Executable>(op->
getBlock())->isLive())
109 resultLattices.push_back(resultLattice);
113 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
114 return visitRegionSuccessors({branch}, branch,
120 if (
auto call = dyn_cast<CallOpInterface>(op)) {
121 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
124 if (!predecessors->allPredecessorsKnown())
126 for (
Operation *predecessor : predecessors->getKnownPredecessors())
127 for (
auto it : llvm::zip(predecessor->getOperands(), resultLattices))
138 operandLattices.push_back(operandLattice);
145 void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
151 if (!getOrCreate<Executable>(block)->isLive())
159 argLattices.push_back(argLattice);
166 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
167 if (callable && callable.getCallableRegion() == block->
getParent()) {
168 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
171 if (!callsites->allPredecessorsKnown())
173 for (
Operation *callsite : callsites->getKnownPredecessors()) {
174 auto call = cast<CallOpInterface>(callsite);
175 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
182 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
183 return visitRegionSuccessors(block, branch, block->
getParent(),
196 Block *predecessor = *it;
200 auto *edgeExecutable =
201 getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
202 edgeExecutable->blockContentSubscribe(
this);
203 if (!edgeExecutable->isLive())
210 branch.getSuccessorOperands(it.getSuccessorIndex());
212 if (
Value operand = operands[idx]) {
226 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
229 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
230 assert(predecessors->allPredecessorsKnown() &&
231 "unexpected unresolved region successors");
233 for (
Operation *op : predecessors->getKnownPredecessors()) {
235 std::optional<OperandRange> operands;
239 operands = branch.getEntrySuccessorOperands(successor);
241 }
else if (
auto regionTerminator =
242 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
243 operands = regionTerminator.getSuccessorOperands(successor);
251 ValueRange inputs = predecessors->getSuccessorInputs(op);
252 assert(inputs.size() == operands->size() &&
253 "expected the same number of successor inputs as operands");
255 unsigned firstIndex = 0;
256 if (inputs.size() != lattices.size()) {
257 if (llvm::dyn_cast_if_present<Operation *>(point)) {
259 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
263 branch->getResults().slice(firstIndex, inputs.size())),
264 lattices, firstIndex);
267 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
268 Region *region = point.get<
Block *>()->getParent();
272 firstIndex, inputs.size())),
273 lattices, firstIndex);
277 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
308 registerPointKind<CFGEdge>();
313 return initializeRecursively(top);
317 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
320 for (
Block &block : region) {
321 getOrCreate<Executable>(&block)->blockContentSubscribe(
this);
325 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
326 if (
failed(initializeRecursively(&*it)))
335 if (
Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
337 else if (llvm::dyn_cast_if_present<Block *>(point))
351 resultLattices.reserve(values.size());
352 for (
Value result : values) {
354 resultLattices.push_back(resultLattice);
356 return resultLattices;
360 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
363 resultLattices.reserve(values.size());
364 for (
Value result : values) {
366 getLatticeElementFor(point, result);
367 resultLattices.push_back(resultLattice);
369 return resultLattices;
376 void AbstractSparseBackwardDataFlowAnalysis::visitOperation(
Operation *op) {
378 if (!getOrCreate<Executable>(op->
getBlock())->isLive())
388 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
389 visitRegionSuccessors(branch, operandLattices);
393 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
404 if (!forwarded.empty()) {
408 unaccounted.reset(operand.getOperandNumber());
409 if (std::optional<BlockArgument> blockArg =
411 successorOperands, operand.getOperandNumber(), block)) {
413 *getLatticeElementFor(op, *blockArg));
420 for (
int index : unaccounted.set_bits()) {
429 if (
auto call = dyn_cast<CallOpInterface>(op)) {
430 Operation *callableOp = call.resolveCallable(&symbolTable);
431 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
439 Region *region = callable.getCallableRegion();
440 if (region && !region->
empty()) {
442 for (
auto [blockArg, argOpOperand] :
445 *getLatticeElementFor(op, blockArg));
446 unaccounted.reset(argOpOperand.getOperandNumber());
451 for (
int index : unaccounted.set_bits()) {
469 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
470 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
471 visitRegionSuccessorsFromTerminator(terminator, branch);
479 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
481 getOrCreateFor<PredecessorState>(op, callable);
485 getLatticeElementsFor(op, call->getResults());
486 for (
auto [op, result] :
487 llvm::zip(operandLattices, callResultLattices))
503 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
504 RegionBranchOpInterface branch,
509 branch.getEntrySuccessorRegions(operands, successors);
516 OperandRange operands = branch.getEntrySuccessorOperands(successor);
518 ValueRange inputs = successor.getSuccessorInputs();
519 for (
auto [operand, input] : llvm::zip(opoperands, inputs)) {
521 unaccounted.reset(operand.getOperandNumber());
526 for (
int index : unaccounted.set_bits()) {
531 void AbstractSparseBackwardDataFlowAnalysis::
532 visitRegionSuccessorsFromTerminator(
533 RegionBranchTerminatorOpInterface terminator,
534 RegionBranchOpInterface branch) {
535 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
536 "expected a `RegionBranchTerminatorOpInterface` op");
537 assert(terminator->getParentOp() == branch.getOperation() &&
538 "expected `branch` to be the parent op of `terminator`");
543 terminator.getSuccessorRegions(operandAttributes, successors);
546 BitVector unaccounted(terminator->getNumOperands(),
true);
549 ValueRange inputs = successor.getSuccessorInputs();
550 OperandRange operands = terminator.getSuccessorOperands(successor);
552 for (
auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
554 *getLatticeElementFor(terminator, input));
555 unaccounted.reset(
const_cast<OpOperand &
>(opOperand).getOperandNumber());
560 for (
int index : unaccounted.set_bits()) {
566 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
ProgramPoint point,
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
ProgramPoint point
The program point to which the state belongs.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
void addDependency(AnalysisState *state, ProgramPoint point)
Create a dependency between the given analysis state and program point on this analysis.
The general data-flow analysis solver.
void enqueue(WorkItem item)
Push a work item onto the worklist.
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Implement a predecessor iterator for blocks.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This class models how operands are forwarded to block arguments in control flow.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s).
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual void visitCallOperand(OpOperand &operand)=0
LogicalResult visit(ProgramPoint point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s).
LogicalResult visit(ProgramPoint point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual void visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
virtual ChangeResult join(const AbstractSparseLattice &rhs)
Join the information contained in 'rhs' into this lattice.
virtual ChangeResult meet(const AbstractSparseLattice &rhs)
Meet (intersect) the information in this lattice with 'rhs'.
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
This analysis state represents a set of live control-flow "predecessors" of a program point (either a...
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand operandIndex in some successor if operandIndex is w...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This trait indicates that a terminator operation is "return-like".
Fundamental IR components are supported as first-class program points.