MLIR  16.0.0git
SparseAnalysis.cpp
Go to the documentation of this file.
1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 using namespace mlir;
14 using namespace mlir::dataflow;
15 
16 //===----------------------------------------------------------------------===//
17 // AbstractSparseLattice
18 //===----------------------------------------------------------------------===//
19 
21  // Push all users of the value to the queue.
22  for (Operation *user : point.get<Value>().getUsers())
23  for (DataFlowAnalysis *analysis : useDefSubscribers)
24  solver->enqueue({user, analysis});
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // AbstractSparseDataFlowAnalysis
29 //===----------------------------------------------------------------------===//
30 
32  DataFlowSolver &solver)
33  : DataFlowAnalysis(solver) {
34  registerPointKind<CFGEdge>();
35 }
36 
38  // Mark the entry block arguments as having reached their pessimistic
39  // fixpoints.
40  for (Region &region : top->getRegions()) {
41  if (region.empty())
42  continue;
43  for (Value argument : region.front().getArguments())
45  }
46 
47  return initializeRecursively(top);
48 }
49 
51 AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
52  // Initialize the analysis by visiting every owner of an SSA value (all
53  // operations and blocks).
54  visitOperation(op);
55  for (Region &region : op->getRegions()) {
56  for (Block &block : region) {
57  getOrCreate<Executable>(&block)->blockContentSubscribe(this);
58  visitBlock(&block);
59  for (Operation &op : block)
60  if (failed(initializeRecursively(&op)))
61  return failure();
62  }
63  }
64 
65  return success();
66 }
67 
69  if (Operation *op = point.dyn_cast<Operation *>())
70  visitOperation(op);
71  else if (Block *block = point.dyn_cast<Block *>())
72  visitBlock(block);
73  else
74  return failure();
75  return success();
76 }
77 
78 void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
79  // Exit early on operations with no results.
80  if (op->getNumResults() == 0)
81  return;
82 
83  // If the containing block is not executable, bail out.
84  if (!getOrCreate<Executable>(op->getBlock())->isLive())
85  return;
86 
87  // Get the result lattices.
89  resultLattices.reserve(op->getNumResults());
90  for (Value result : op->getResults()) {
91  AbstractSparseLattice *resultLattice = getLatticeElement(result);
92  resultLattices.push_back(resultLattice);
93  }
94 
95  // The results of a region branch operation are determined by control-flow.
96  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
97  return visitRegionSuccessors({branch}, branch,
98  /*successorIndex=*/std::nullopt,
99  resultLattices);
100  }
101 
102  // The results of a call operation are determined by the callgraph.
103  if (auto call = dyn_cast<CallOpInterface>(op)) {
104  const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
105  // If not all return sites are known, then conservatively assume we can't
106  // reason about the data-flow.
107  if (!predecessors->allPredecessorsKnown())
108  return setAllToEntryStates(resultLattices);
109  for (Operation *predecessor : predecessors->getKnownPredecessors())
110  for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
111  join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
112  return;
113  }
114 
115  // Grab the lattice elements of the operands.
117  operandLattices.reserve(op->getNumOperands());
118  for (Value operand : op->getOperands()) {
119  AbstractSparseLattice *operandLattice = getLatticeElement(operand);
120  operandLattice->useDefSubscribe(this);
121  operandLattices.push_back(operandLattice);
122  }
123 
124  // Invoke the operation transfer function.
125  visitOperationImpl(op, operandLattices, resultLattices);
126 }
127 
128 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
129  // Exit early on blocks with no arguments.
130  if (block->getNumArguments() == 0)
131  return;
132 
133  // If the block is not executable, bail out.
134  if (!getOrCreate<Executable>(block)->isLive())
135  return;
136 
137  // Get the argument lattices.
139  argLattices.reserve(block->getNumArguments());
140  for (BlockArgument argument : block->getArguments()) {
141  AbstractSparseLattice *argLattice = getLatticeElement(argument);
142  argLattices.push_back(argLattice);
143  }
144 
145  // The argument lattices of entry blocks are set by region control-flow or the
146  // callgraph.
147  if (block->isEntryBlock()) {
148  // Check if this block is the entry block of a callable region.
149  auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
150  if (callable && callable.getCallableRegion() == block->getParent()) {
151  const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
152  // If not all callsites are known, conservatively mark all lattices as
153  // having reached their pessimistic fixpoints.
154  if (!callsites->allPredecessorsKnown())
155  return setAllToEntryStates(argLattices);
156  for (Operation *callsite : callsites->getKnownPredecessors()) {
157  auto call = cast<CallOpInterface>(callsite);
158  for (auto it : llvm::zip(call.getArgOperands(), argLattices))
159  join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
160  }
161  return;
162  }
163 
164  // Check if the lattices can be determined from region control flow.
165  if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
166  return visitRegionSuccessors(
167  block, branch, block->getParent()->getRegionNumber(), argLattices);
168  }
169 
170  // Otherwise, we can't reason about the data-flow.
172  RegionSuccessor(block->getParent()),
173  argLattices, /*firstIndex=*/0);
174  }
175 
176  // Iterate over the predecessors of the non-entry block.
177  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
178  it != e; ++it) {
179  Block *predecessor = *it;
180 
181  // If the edge from the predecessor block to the current block is not live,
182  // bail out.
183  auto *edgeExecutable =
184  getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
185  edgeExecutable->blockContentSubscribe(this);
186  if (!edgeExecutable->isLive())
187  continue;
188 
189  // Check if we can reason about the data-flow from the predecessor.
190  if (auto branch =
191  dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
192  SuccessorOperands operands =
193  branch.getSuccessorOperands(it.getSuccessorIndex());
194  for (auto &it : llvm::enumerate(argLattices)) {
195  if (Value operand = operands[it.index()]) {
196  join(it.value(), *getLatticeElementFor(block, operand));
197  } else {
198  // Conservatively consider internally produced arguments as entry
199  // points.
200  setAllToEntryStates(it.value());
201  }
202  }
203  } else {
204  return setAllToEntryStates(argLattices);
205  }
206  }
207 }
208 
209 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
210  ProgramPoint point, RegionBranchOpInterface branch,
211  Optional<unsigned> successorIndex,
213  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
214  assert(predecessors->allPredecessorsKnown() &&
215  "unexpected unresolved region successors");
216 
217  for (Operation *op : predecessors->getKnownPredecessors()) {
218  // Get the incoming successor operands.
219  Optional<OperandRange> operands;
220 
221  // Check if the predecessor is the parent op.
222  if (op == branch) {
223  operands = branch.getSuccessorEntryOperands(successorIndex);
224  // Otherwise, try to deduce the operands from a region return-like op.
225  } else {
226  if (isRegionReturnLike(op))
227  operands = getRegionBranchSuccessorOperands(op, successorIndex);
228  }
229 
230  if (!operands) {
231  // We can't reason about the data-flow.
232  return setAllToEntryStates(lattices);
233  }
234 
235  ValueRange inputs = predecessors->getSuccessorInputs(op);
236  assert(inputs.size() == operands->size() &&
237  "expected the same number of successor inputs as operands");
238 
239  unsigned firstIndex = 0;
240  if (inputs.size() != lattices.size()) {
241  if (point.dyn_cast<Operation *>()) {
242  if (!inputs.empty())
243  firstIndex = inputs.front().cast<OpResult>().getResultNumber();
245  branch,
247  branch->getResults().slice(firstIndex, inputs.size())),
248  lattices, firstIndex);
249  } else {
250  if (!inputs.empty())
251  firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
252  Region *region = point.get<Block *>()->getParent();
254  branch,
255  RegionSuccessor(region, region->getArguments().slice(
256  firstIndex, inputs.size())),
257  lattices, firstIndex);
258  }
259  }
260 
261  for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
262  join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
263  }
264 }
265 
266 const AbstractSparseLattice *
268  Value value) {
270  addDependency(state, point);
271  return state;
272 }
273 
276  for (AbstractSparseLattice *lattice : lattices)
277  setToEntryState(lattice);
278 }
279 
281  const AbstractSparseLattice &rhs) {
282  propagateIfChanged(lhs, lhs->join(rhs));
283 }
static constexpr const bool value
ProgramPoint point
The program point to which the state belongs.
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumArguments()
Definition: Block.h:117
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
pred_iterator pred_begin()
Definition: Block.h:219
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgListType getArguments()
Definition: Block.h:76
bool isEntryBlock()
Return true if this block is the entry block of its parent region.
Definition: Block.cpp:35
pred_iterator pred_end()
Definition: Block.h:222
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Base class for all data-flow analyses.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
void addDependency(AnalysisState *state, ProgramPoint point)
Create a dependency between the given analysis state and program point on this analysis.
The general data-flow analysis solver.
void enqueue(WorkItem item)
Push a work item onto the worklist.
This is a value defined by a result of an operation.
Definition: Value.h:442
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
unsigned getNumOperands()
Definition: Operation.h:263
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:295
result_range getResults()
Definition: Operation.h:332
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
Implement a predecessor iterator for blocks.
Definition: BlockSupport.h:51
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
This class models how operands are forwarded to block arguments in control flow.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
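This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users. An SSA value is either a BlockArgument or the result of an operation.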
Definition: Value.h:85
user_range getUsers() const
Definition: Value.h:209
virtual AbstractSparseLattice * getLatticeElement(Value value)=0
Get the lattice element of a value.
AbstractSparseDataFlowAnalysis(DataFlowSolver &solver)
virtual void visitOperationImpl(Operation *op, ArrayRef< const AbstractSparseLattice * > operandLattices, ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
LogicalResult visit(ProgramPoint point) override
Visit a program point.
void setAllToEntryStates(ArrayRef< AbstractSparseLattice * > lattices)
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate an update if it changed.
const AbstractSparseLattice * getLatticeElementFor(ProgramPoint point, Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
LogicalResult initialize(Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(Operation *op, const RegionSuccessor &successor, ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
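Given an operation with region control-flow, the lattices of the operands, and a region successor, compute the lattice values for block arguments that are not accounted for by the branching control flow (e.g. the bounds of loops).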
This class represents an abstract lattice.
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to subscribed analyses.
void useDefSubscribe(DataFlowAnalysis *analysis)
Subscribe an analysis to updates of the lattice.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Optional< OperandRange > getRegionBranchSuccessorOperands(Operation *operation, Optional< unsigned > regionIndex)
Returns the read-only operands that are passed to the region with the given regionIndex.
bool isRegionReturnLike(Operation *operation)
Returns true if the given operation is either annotated with the ReturnLike trait or implements the RegionBranchTerminatorOpInterface.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Fundamental IR components are supported as first-class program points.