22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Casting.h"
40 solver->
enqueue({user, analysis});
50 registerPointKind<CFGEdge>();
60 for (
Value argument : region.front().getArguments())
64 return initializeRecursively(top);
68 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
73 for (
Block &block : region) {
74 getOrCreate<Executable>(&block)->blockContentSubscribe(
this);
77 if (
failed(initializeRecursively(&op)))
86 if (
Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
88 else if (
Block *block = llvm::dyn_cast_if_present<Block *>(point))
95 void AbstractSparseForwardDataFlowAnalysis::visitOperation(
Operation *op) {
101 if (!getOrCreate<Executable>(op->
getBlock())->isLive())
109 resultLattices.push_back(resultLattice);
113 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
114 return visitRegionSuccessors({branch}, branch,
125 operandLattices.push_back(operandLattice);
128 if (
auto call = dyn_cast<CallOpInterface>(op)) {
132 dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
134 (callable && !callable.getCallableRegion())) {
140 const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
143 if (!predecessors->allPredecessorsKnown())
145 for (
Operation *predecessor : predecessors->getKnownPredecessors())
146 for (
auto it : llvm::zip(predecessor->getOperands(), resultLattices))
155 void AbstractSparseForwardDataFlowAnalysis::visitBlock(
Block *block) {
161 if (!getOrCreate<Executable>(block)->isLive())
169 argLattices.push_back(argLattice);
176 auto callable = dyn_cast<CallableOpInterface>(block->
getParentOp());
177 if (callable && callable.getCallableRegion() == block->
getParent()) {
178 const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
181 if (!callsites->allPredecessorsKnown() ||
185 for (
Operation *callsite : callsites->getKnownPredecessors()) {
186 auto call = cast<CallOpInterface>(callsite);
187 for (
auto it : llvm::zip(call.getArgOperands(), argLattices))
194 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->
getParentOp())) {
195 return visitRegionSuccessors(block, branch, block->
getParent(),
208 Block *predecessor = *it;
212 auto *edgeExecutable =
213 getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
214 edgeExecutable->blockContentSubscribe(
this);
215 if (!edgeExecutable->isLive())
222 branch.getSuccessorOperands(it.getSuccessorIndex());
224 if (
Value operand = operands[idx]) {
238 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
241 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
242 assert(predecessors->allPredecessorsKnown() &&
243 "unexpected unresolved region successors");
245 for (
Operation *op : predecessors->getKnownPredecessors()) {
247 std::optional<OperandRange> operands;
251 operands = branch.getEntrySuccessorOperands(successor);
253 }
else if (
auto regionTerminator =
254 dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
255 operands = regionTerminator.getSuccessorOperands(successor);
263 ValueRange inputs = predecessors->getSuccessorInputs(op);
264 assert(inputs.size() == operands->size() &&
265 "expected the same number of successor inputs as operands");
267 unsigned firstIndex = 0;
268 if (inputs.size() != lattices.size()) {
269 if (llvm::dyn_cast_if_present<Operation *>(point)) {
271 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
275 branch->getResults().slice(firstIndex, inputs.size())),
276 lattices, firstIndex);
279 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
280 Region *region = point.get<
Block *>()->getParent();
284 firstIndex, inputs.size())),
285 lattices, firstIndex);
289 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
320 registerPointKind<CFGEdge>();
325 return initializeRecursively(top);
329 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(
Operation *op) {
332 for (
Block &block : region) {
333 getOrCreate<Executable>(&block)->blockContentSubscribe(
this);
337 for (
auto it = block.
rbegin(); it != block.
rend(); it++)
338 if (
failed(initializeRecursively(&*it)))
347 if (
Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
349 else if (llvm::dyn_cast_if_present<Block *>(point))
363 resultLattices.reserve(values.size());
364 for (
Value result : values) {
366 resultLattices.push_back(resultLattice);
368 return resultLattices;
372 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
375 resultLattices.reserve(values.size());
376 for (
Value result : values) {
378 getLatticeElementFor(point, result);
379 resultLattices.push_back(resultLattice);
381 return resultLattices;
388 void AbstractSparseBackwardDataFlowAnalysis::visitOperation(
Operation *op) {
390 if (!getOrCreate<Executable>(op->
getBlock())->isLive())
400 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
401 visitRegionSuccessors(branch, operandLattices);
405 if (
auto branch = dyn_cast<BranchOpInterface>(op)) {
416 if (!forwarded.empty()) {
420 unaccounted.reset(operand.getOperandNumber());
421 if (std::optional<BlockArgument> blockArg =
423 successorOperands, operand.getOperandNumber(), block)) {
425 *getLatticeElementFor(op, *blockArg));
432 for (
int index : unaccounted.set_bits()) {
441 if (
auto call = dyn_cast<CallOpInterface>(op)) {
442 Operation *callableOp = call.resolveCallable(&symbolTable);
443 if (
auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
454 Region *region = callable.getCallableRegion();
461 for (
auto [blockArg, argOpOperand] :
464 *getLatticeElementFor(op, blockArg));
465 unaccounted.reset(argOpOperand.getOperandNumber());
470 for (
int index : unaccounted.set_bits()) {
488 if (
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
489 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op->
getParentOp())) {
490 visitRegionSuccessorsFromTerminator(terminator, branch);
498 if (
auto callable = dyn_cast<CallableOpInterface>(op->
getParentOp())) {
500 getOrCreateFor<PredecessorState>(op, callable);
504 getLatticeElementsFor(op, call->getResults());
505 for (
auto [op, result] :
506 llvm::zip(operandLattices, callResultLattices))
522 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
523 RegionBranchOpInterface branch,
528 branch.getEntrySuccessorRegions(operands, successors);
535 OperandRange operands = branch.getEntrySuccessorOperands(successor);
537 ValueRange inputs = successor.getSuccessorInputs();
538 for (
auto [operand, input] : llvm::zip(opoperands, inputs)) {
540 unaccounted.reset(operand.getOperandNumber());
545 for (
int index : unaccounted.set_bits()) {
550 void AbstractSparseBackwardDataFlowAnalysis::
551 visitRegionSuccessorsFromTerminator(
552 RegionBranchTerminatorOpInterface terminator,
553 RegionBranchOpInterface branch) {
554 assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
555 "expected a `RegionBranchTerminatorOpInterface` op");
556 assert(terminator->getParentOp() == branch.getOperation() &&
557 "expected `branch` to be the parent op of `terminator`");
562 terminator.getSuccessorRegions(operandAttributes, successors);
565 BitVector unaccounted(terminator->getNumOperands(),
true);
568 ValueRange inputs = successor.getSuccessorInputs();
569 OperandRange operands = terminator.getSuccessorOperands(successor);
571 for (
auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
573 *getLatticeElementFor(terminator, input));
574 unaccounted.reset(
const_cast<OpOperand &
>(opOperand).getOperandNumber());
579 for (
int index : unaccounted.set_bits()) {
585 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
ProgramPoint point,
static MutableArrayRef< OpOperand > operandsToOpOperands(OperandRange &operands)
virtual void onUpdate(DataFlowSolver *solver) const
This function is called by the solver when the analysis state is updated to enqueue more work items.
ProgramPoint point
The program point to which the state belongs.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
reverse_iterator rbegin()
Base class for all data-flow analyses.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
const DataFlowConfig & getSolverConfig() const
Return the configuration of the solver used for this analysis.
void addDependency(AnalysisState *state, ProgramPoint point)
Create a dependency between the given analysis state and program point on this analysis.
The general data-flow analysis solver.
void enqueue(WorkItem item)
Push a work item onto the worklist.
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumOperands()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Implement a predecessor iterator for blocks.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
static constexpr RegionBranchPoint parent()
Returns an instance of RegionBranchPoint representing the parent operation.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
This class models how operands are forwarded to block arguments in control flow.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
user_range getUsers() const
virtual void setToExitState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow exit point(s).
SmallVector< AbstractSparseLattice * > getLatticeElements(ValueRange values)
Get the lattice elements for a range of values.
AbstractSparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
virtual void visitBranchOperand(OpOperand &operand)=0
virtual void visitOperationImpl(Operation *op, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual void visitCallOperand(OpOperand &operand)=0
LogicalResult visit(ProgramPoint point) override
Visit a program point.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Meet the lattice element with `rhs` and propagate an update if it changed.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< AbstractSparseLattice * > operandLattices, ArrayRef< const AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element for a value.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting the operation and everything nested under it.
void setAllToExitStates(ArrayRef< AbstractSparseLattice * > lattices)
Set the given lattice element(s) at control flow exit point(s).
LogicalResult visit(ProgramPoint point) override
Visit a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual void visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
virtual void visitExternalCallImpl(CallOpInterface call, ArrayRef< const AbstractSparseLattice * > argumentLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver)
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
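Given an operation with region control-flow, the lattices of the operands, and a region successor, compute the lattice values for block arguments that are not accounted for by the branching control flow (e.g. the bounds of loops).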
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element with `rhs` and propagate an update if it changed.
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value, using its use-def chain, to subscribed analyses.
virtual ChangeResult join(const AbstractSparseLattice &rhs)
Join the information contained in 'rhs' into this lattice.
virtual ChangeResult meet(const AbstractSparseLattice &rhs)
Meet (intersect) the information in this lattice with 'rhs'.
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
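This analysis state represents a set of live control-flow "predecessors" of a program point (either an operation or a block): the last operations along all execution paths that pass through this program point.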
bool allPredecessorsKnown() const
Returns true if all predecessors are known.
ArrayRef< Operation * > getKnownPredecessors() const
Get the known predecessors.
std::optional< BlockArgument > getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor)
Return the BlockArgument corresponding to operand `operandIndex` in some successor if `operandIndex` is within the range of `operands`, or std::nullopt if `operandIndex` is not a successor operand index.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This trait indicates that a terminator operation is "return-like".
Fundamental IR components are supported as first-class program points.