MLIR  19.0.0git
DataFlowFramework.cpp
Go to the documentation of this file.
1 //===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Location.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/Value.h"
14 #include "llvm/ADT/iterator.h"
15 #include "llvm/Config/abi-breaking.h"
16 #include "llvm/Support/Casting.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 #define DEBUG_TYPE "dataflow"
21 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
22 #define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
23 #else
24 #define DATAFLOW_DEBUG(X)
25 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // GenericProgramPoint
31 //===----------------------------------------------------------------------===//
32 
34 
35 //===----------------------------------------------------------------------===//
36 // AnalysisState
37 //===----------------------------------------------------------------------===//
38 
40 
42  DataFlowAnalysis *analysis) {
43  auto inserted = dependents.insert({dependent, analysis});
44  (void)inserted;
46  if (inserted) {
47  llvm::dbgs() << "Creating dependency between " << debugName << " of "
48  << point << "\nand " << debugName << " on " << dependent
49  << "\n";
50  }
51  });
52 }
53 
54 void AnalysisState::dump() const { print(llvm::errs()); }
55 
56 //===----------------------------------------------------------------------===//
57 // ProgramPoint
58 //===----------------------------------------------------------------------===//
59 
60 void ProgramPoint::print(raw_ostream &os) const {
61  if (isNull()) {
62  os << "<NULL POINT>";
63  return;
64  }
65  if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
66  return programPoint->print(os);
67  if (auto *op = llvm::dyn_cast<Operation *>(*this))
68  return op->print(os, OpPrintingFlags().skipRegions());
69  if (auto value = llvm::dyn_cast<Value>(*this))
70  return value.print(os, OpPrintingFlags().skipRegions());
71  return get<Block *>()->print(os);
72 }
73 
75  if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
76  return programPoint->getLoc();
77  if (auto *op = llvm::dyn_cast<Operation *>(*this))
78  return op->getLoc();
79  if (auto value = llvm::dyn_cast<Value>(*this))
80  return value.getLoc();
81  return get<Block *>()->getParent()->getLoc();
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // DataFlowSolver
86 //===----------------------------------------------------------------------===//
87 
89  // Initialize the analyses.
90  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
91  DATAFLOW_DEBUG(llvm::dbgs()
92  << "Priming analysis: " << analysis.debugName << "\n");
93  if (failed(analysis.initialize(top)))
94  return failure();
95  }
96 
97  // Run the analysis until fixpoint.
98  do {
99  // Exhaust the worklist.
100  while (!worklist.empty()) {
101  auto [point, analysis] = worklist.front();
102  worklist.pop();
103 
104  DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
105  << "' on: " << point << "\n");
106  if (failed(analysis->visit(point)))
107  return failure();
108  }
109 
110  // Iterate until all states are in some initialized state and the worklist
111  // is exhausted.
112  } while (!worklist.empty());
113 
114  return success();
115 }
116 
118  ChangeResult changed) {
119  if (changed == ChangeResult::Change) {
120  DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
121  << " of " << state->point << "\n"
122  << "Value: " << *state << "\n");
123  state->onUpdate(this);
124  }
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // DataFlowAnalysis
129 //===----------------------------------------------------------------------===//
130 
132 
134 
136  state->addDependency(point, this);
137 }
138 
140  ChangeResult changed) {
141  solver.propagateIfChanged(state, changed);
142 }
#define DATAFLOW_DEBUG(X)
Base class for generic analysis states.
LLVM_DUMP_METHOD void dump() const
virtual void print(raw_ostream &os) const =0
Print the contents of the analysis state.
void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis)
Add a dependency to this analysis state on a program point and an analysis.
virtual ~AnalysisState()
ProgramPoint point
The program point to which the state belongs.
Base class for all data-flow analyses.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
DataFlowAnalysis(DataFlowSolver &solver)
Create an analysis with a reference to the parent solver.
void addDependency(AnalysisState *state, ProgramPoint point)
Create a dependency between the given analysis state and program point on this analysis.
The general data-flow analysis solver.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of the queue.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis until fixpoint.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
ChangeResult
A result type used to indicate if a change happened.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Fundamental IR components are supported as first-class program points.
Location getLoc() const
Get the source location of the program point.
void print(raw_ostream &os) const
Print the program point.