14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/raw_ostream.h"
19 #define GEN_PASS_DEF_PRINTOPSTATS
20 #include "mlir/Transforms/Passes.h.inc"
// Pass implementation layered on the generated impl::PrintOpStatsBase. It
// keeps the destination stream plus a string-keyed counter map that the print
// routines read back.
// NOTE(review): this listing skips source lines (the embedded numbers jump
// 30 -> 34 -> 40 -> 43), so the constructor/struct closing braces and any
// other members (e.g. the `os` field, printSummary's declaration) are not
// visible here.
26 struct PrintOpStatsPass :
public impl::PrintOpStatsBase<PrintOpStatsPass> {
// Constructs the pass with only the stream to print to.
27 explicit PrintOpStatsPass(raw_ostream &os) : os(os) {}
// Constructs the pass with a stream and records the JSON flag in
// `this->printAsJSON` -- presumably a member supplied by the generated base;
// confirm against Passes.h.inc.
29 explicit PrintOpStatsPass(raw_ostream &os,
bool printAsJSON) : os(os) {
30 this->printAsJSON = printAsJSON;
// Pass entry point; defined out of line.
34 void runOnOperation()
override;
// Writes the collected counts in JSON form; defined out of line.
40 void printSummaryInJSON();
// Signed 64-bit counter per string key; the print routines index it with the
// entries of their `sorted` list.
43 llvm::StringMap<int64_t> opCount;
// Pass entry point.
// NOTE(review): listing lines 49-57 are not shown, so the statements that run
// before the call below cannot be described from this view.
48 void PrintOpStatsPass::runOnOperation() {
// Declares that every analysis is still valid after this pass.
58 markAllAnalysesPreserved();
// Writes a human-readable table to `os`: a two-line header, then one row per
// entry of `sorted` showing the dialect part, the op part and the count taken
// from `opCount`.
// NOTE(review): the definition of `sorted`, the tail of the lambda, an
// `else`-like line 90 and the closing braces are elided from this listing.
61 void PrintOpStatsPass::printSummary() {
62 os <<
"Operations encountered:\n";
63 os <<
"-----------------------\n";
// Helper that splits a name around '.'; when nothing follows the split point
// it yields an empty first component and the whole name as the second.
// (The other branch of the conditional is on a line not shown here.)
68 auto splitOperationName = [](StringRef opName) {
69 auto splitName = opName.split(
'.');
70 return splitName.second.empty() ? std::make_pair(
"", splitName.first)
// First pass over `sorted`: find the widest dialect part and the widest op
// part so the columns can be aligned.
75 size_t maxLenOpName = 0, maxLenDialect = 0;
76 for (
const auto &key : sorted) {
77 auto [dialectName, opName] = splitOperationName(key);
78 maxLenDialect =
std::max(maxLenDialect, dialectName.size());
79 maxLenOpName =
std::max(maxLenOpName, opName.size());
// Second pass: emit one row per key.
82 for (
const auto &key : sorted) {
83 auto [dialectName, opName] = splitOperationName(key);
// No dialect part: pad by the dialect column width plus 3, which matches the
// width of the justified dialect (+2) and the '.' written below.
88 if (dialectName.empty())
89 os.indent(maxLenDialect + 3);
// Listing line 90 is not shown -- presumably an `else`; confirm against the
// full source. Writes the right-justified dialect followed by '.'.
91 os << llvm::right_justify(dialectName, maxLenDialect + 2) <<
'.';
// Op part padded to the widest op name, then the separator and the count.
94 os << llvm::left_justify(opName, maxLenOpName) <<
" , " << opCount[key]
// Writes the counts to `os` as quoted-key / value pairs, one per entry of
// `sorted`.
// NOTE(review): listing lines 100-104 and everything after 107 are elided, so
// the definition of `sorted`, any surrounding braces/separators and the
// function's closing lines are not visible here.
99 void PrintOpStatsPass::printSummaryInJSON() {
// Index-based walk over `sorted`; `e` caches the size once up front.
105 for (
unsigned i = 0, e = sorted.size(); i != e; ++i) {
106 const auto &key = sorted[i];
// Emits the key wrapped in double quotes, then " : " and its count.
107 os <<
" \"" << key <<
"\" : " << opCount[key];
// NOTE(review): the enclosing function headers (listing lines ~116 and ~121)
// are not shown; these look like the bodies of two pass-factory overloads --
// confirm against the full source.
// Returns a heap-allocated PrintOpStatsPass built from the stream alone.
117 return std::make_unique<PrintOpStatsPass>(os);
// Returns a heap-allocated PrintOpStatsPass built from the stream and the
// JSON flag.
122 return std::make_unique<PrintOpStatsPass>(os, printAsJSON);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
Include the generated interface declarations.
std::unique_ptr< Pass > createPrintOpStatsPass(raw_ostream &os=llvm::errs())
Creates a pass which prints the list of ops and the number of occurrences in the module.