13#include "llvm/Support/Format.h"
14#include "llvm/Support/raw_ostream.h"
17#define GEN_PASS_DEF_PRINTOPSTATSPASS
18#include "mlir/Transforms/Passes.h.inc"
25 using impl::PrintOpStatsPassBase<PrintOpStatsPass>::PrintOpStatsPassBase;
27 explicit PrintOpStatsPass(raw_ostream &os) : os(&os) {}
29 explicit PrintOpStatsPass(raw_ostream &os,
bool printAsJSON) : os(&os) {
30 this->printAsJSON = printAsJSON;
34 void runOnOperation()
override;
40 void printSummaryInJSON();
43 llvm::StringMap<int64_t> opCount;
44 raw_ostream *os = &llvm::errs();
48void PrintOpStatsPass::runOnOperation() {
58 markAllAnalysesPreserved();
61void PrintOpStatsPass::printSummary() {
62 *os <<
"Operations encountered:\n";
63 *os <<
"-----------------------\n";
64 SmallVector<StringRef, 64> sorted(opCount.keys());
68 auto splitOperationName = [](StringRef opName) {
69 auto splitName = opName.split(
'.');
70 return splitName.second.empty() ? std::make_pair(
"", splitName.first)
75 size_t maxLenOpName = 0, maxLenDialect = 0;
76 for (
const auto &key : sorted) {
77 auto [dialectName, opName] = splitOperationName(key);
78 maxLenDialect = std::max(maxLenDialect, dialectName.size());
79 maxLenOpName = std::max(maxLenOpName, opName.size());
82 for (
const auto &key : sorted) {
83 auto [dialectName, opName] = splitOperationName(key);
88 if (dialectName.empty())
89 os->indent(maxLenDialect + 3);
91 *os << llvm::right_justify(dialectName, maxLenDialect + 2) <<
'.';
94 *os << llvm::left_justify(opName, maxLenOpName) <<
" , " << opCount[key]
99void PrintOpStatsPass::printSummaryInJSON() {
100 SmallVector<StringRef, 64> sorted(opCount.keys());
105 for (
unsigned i = 0, e = sorted.size(); i != e; ++i) {
106 const auto &key = sorted[i];
107 *os <<
" \"" << key <<
"\" : " << opCount[key];
117 return std::make_unique<PrintOpStatsPass>(os);
122 return std::make_unique<PrintOpStatsPass>(os, printAsJSON);
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
OperationName getName()
The name of an operation is the key identifier for it.
Include the generated interface declarations.
std::unique_ptr<::mlir::Pass > createPrintOpStatsPass()