13 #include "llvm/Support/Format.h"
14 #include "llvm/Support/raw_ostream.h"
17 #define GEN_PASS_DEF_PRINTOPSTATS
18 #include "mlir/Transforms/Passes.h.inc"
24 struct PrintOpStatsPass :
public impl::PrintOpStatsBase<PrintOpStatsPass> {
25 explicit PrintOpStatsPass(raw_ostream &os) : os(os) {}
27 explicit PrintOpStatsPass(raw_ostream &os,
bool printAsJSON) : os(os) {
28 this->printAsJSON = printAsJSON;
32 void runOnOperation()
override;
38 void printSummaryInJSON();
41 llvm::StringMap<int64_t> opCount;
46 void PrintOpStatsPass::runOnOperation() {
56 markAllAnalysesPreserved();
59 void PrintOpStatsPass::printSummary() {
60 os <<
"Operations encountered:\n";
61 os <<
"-----------------------\n";
66 auto splitOperationName = [](StringRef opName) {
67 auto splitName = opName.split(
'.');
68 return splitName.second.empty() ? std::make_pair(
"", splitName.first)
73 size_t maxLenOpName = 0, maxLenDialect = 0;
74 for (
const auto &key : sorted) {
75 auto [dialectName, opName] = splitOperationName(key);
76 maxLenDialect =
std::max(maxLenDialect, dialectName.size());
77 maxLenOpName =
std::max(maxLenOpName, opName.size());
80 for (
const auto &key : sorted) {
81 auto [dialectName, opName] = splitOperationName(key);
86 if (dialectName.empty())
87 os.indent(maxLenDialect + 3);
89 os << llvm::right_justify(dialectName, maxLenDialect + 2) <<
'.';
92 os << llvm::left_justify(opName, maxLenOpName) <<
" , " << opCount[key]
97 void PrintOpStatsPass::printSummaryInJSON() {
103 for (
unsigned i = 0, e = sorted.size(); i != e; ++i) {
104 const auto &key = sorted[i];
105 os <<
" \"" << key <<
"\" : " << opCount[key];
115 return std::make_unique<PrintOpStatsPass>(os);
120 return std::make_unique<PrintOpStatsPass>(os, printAsJSON);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpStatsPass(raw_ostream &os=llvm::errs())
Creates a pass which prints the list of ops and the number of occurrences in the module.