13#include "llvm/ADT/STLExtras.h"
14#include "llvm/ADT/StringExtras.h"
15#include "llvm/Support/FileSystem.h"
16#include "llvm/Support/FormatVariadic.h"
17#include "llvm/Support/Path.h"
18#include "llvm/Support/ToolOutputFile.h"
30 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
31 : config(std::move(config)) {}
35 void runBeforePass(Pass *pass, Operation *op)
override;
36 void runAfterPass(Pass *pass, Operation *op)
override;
37 void runAfterPassFailed(Pass *pass, Operation *op)
override;
40 std::unique_ptr<PassManager::IRPrinterConfig> config;
52 if (!printModuleScope)
53 return op->
print(out <<
" //----- //\n",
57 out <<
" ('" << op->
getName() <<
"' operation";
60 out <<
": @" << symbolName.getValue();
61 out <<
") //----- //\n";
64 auto *topLevelOp = op;
65 while (
auto *parentOp = topLevelOp->getParentOp())
66 topLevelOp = parentOp;
67 topLevelOp->
print(out, flags);
71void IRPrinterInstrumentation::runBeforePass(
Pass *pass,
Operation *op) {
72 if (isa<OpToOpPassAdaptor>(pass))
75 if (
config->shouldPrintAfterOnlyOnChange())
76 beforePassFingerPrints.try_emplace(pass, op);
78 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
79 out <<
"// -----// IR Dump Before " << pass->
getName() <<
" ("
82 config->getOpPrintingFlags());
87void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
88 if (isa<OpToOpPassAdaptor>(pass))
92 if (
config->shouldPrintAfterOnlyOnFailure())
97 if (
config->shouldPrintAfterOnlyOnChange()) {
98 auto fingerPrintIt = beforePassFingerPrints.find(pass);
99 assert(fingerPrintIt != beforePassFingerPrints.end() &&
100 "expected valid fingerprint");
102 if (fingerPrintIt->second == OperationFingerPrint(op)) {
103 beforePassFingerPrints.erase(fingerPrintIt);
106 beforePassFingerPrints.erase(fingerPrintIt);
109 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
110 out <<
"// -----// IR Dump After " << pass->
getName() <<
" ("
113 config->getOpPrintingFlags());
118void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
119 if (isa<OpToOpPassAdaptor>(pass))
121 if (
config->shouldPrintAfterOnlyOnChange())
122 beforePassFingerPrints.erase(pass);
124 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
125 out << formatv(
"// -----// IR Dump After {0} Failed ({1})", pass->
getName(),
128 config->getOpPrintingFlags());
139 bool printAfterOnlyOnChange,
140 bool printAfterOnlyOnFailure,
142 : printModuleScope(printModuleScope),
143 printAfterOnlyOnChange(printAfterOnlyOnChange),
144 printAfterOnlyOnFailure(printAfterOnlyOnFailure),
145 opPrintingFlags(opPrintingFlags) {}
173 BasicIRPrinterConfig(
174 std::function<
bool(
Pass *,
Operation *)> shouldPrintBeforePass,
175 std::function<
bool(
Pass *,
Operation *)> shouldPrintAfterPass,
176 bool printModuleScope,
bool printAfterOnlyOnChange,
180 printAfterOnlyOnFailure, opPrintingFlags),
181 shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
182 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
183 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
184 "expected at least one valid filter function");
188 PrintCallbackFn printCallback)
final {
189 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
194 PrintCallbackFn printCallback)
final {
195 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
200 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
201 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
214static std::pair<SmallVector<std::pair<std::string, std::string>>, std::string>
221 ++counters.try_emplace(op, -1).first->second;
223 countPrefix.push_back(counters[iter]);
224 StringAttr symbolNameAttr =
226 std::string symbolName =
227 symbolNameAttr ? symbolNameAttr.str() :
"no-symbol-name";
228 llvm::replace(symbolName,
'/',
'_');
229 llvm::replace(symbolName,
'\\',
'_');
232 llvm::join(llvm::split(iter->getName().getStringRef().str(),
'.'),
"_");
233 pathElements.emplace_back(std::move(opName), std::move(symbolName));
234 iter = iter->getParentOp();
237 std::reverse(countPrefix.begin(), countPrefix.end());
238 std::reverse(pathElements.begin(), pathElements.end());
240 std::string passFileName = llvm::formatv(
242 llvm::make_range(countPrefix.begin(), countPrefix.end()), passName);
244 return {pathElements, passFileName};
248 if (std::error_code ec =
249 llvm::sys::fs::create_directory(dirPath,
true)) {
250 llvm::errs() <<
"Error while creating directory " << dirPath <<
": "
251 << ec.message() <<
"\n";
259static std::unique_ptr<llvm::ToolOutputFile>
261 llvm::StringRef rootDir,
267 auto [opAndSymbolNames, fileName] =
276 for (
const auto &[opName, symbolName] : opAndSymbolNames) {
277 llvm::sys::path::append(path, opName +
"_" + symbolName);
283 llvm::sys::path::append(path, fileName);
285 std::unique_ptr<llvm::ToolOutputFile> file =
openOutputFile(path, &error);
287 llvm::errs() <<
"Error opening output file " << path <<
": " << error
298struct FileTreeIRPrinterConfig :
public PassManager::IRPrinterConfig {
299 FileTreeIRPrinterConfig(
300 std::function<
bool(Pass *, Operation *)> shouldPrintBeforePass,
301 std::function<
bool(Pass *, Operation *)> shouldPrintAfterPass,
302 bool printModuleScope,
bool printAfterOnlyOnChange,
303 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
304 llvm::StringRef treeDir)
305 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
306 printAfterOnlyOnFailure, opPrintingFlags),
307 shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
308 shouldPrintAfterPass(std::move(shouldPrintAfterPass)),
310 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
311 "expected at least one valid filter function");
314 void printBeforeIfEnabled(Pass *pass, Operation *operation,
315 PrintCallbackFn printCallback)
final {
316 if (!shouldPrintBeforePass || !shouldPrintBeforePass(pass, operation))
319 operation, pass->
getArgument(), treeDir, counters);
322 printCallback(file->os());
326 void printAfterIfEnabled(Pass *pass, Operation *operation,
327 PrintCallbackFn printCallback)
final {
328 if (!shouldPrintAfterPass || !shouldPrintAfterPass(pass, operation))
331 operation, pass->
getArgument(), treeDir, counters);
334 printCallback(file->os());
339 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
340 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
347 llvm::DenseMap<Operation *, unsigned> counters;
355 if (
config->shouldPrintAtModuleScope() &&
357 llvm::report_fatal_error(
"IR printing can't be setup on a pass-manager "
358 "without disabling multi-threading first.");
360 std::make_unique<IRPrinterInstrumentation>(std::move(
config)));
365 std::function<
bool(
Pass *,
Operation *)> shouldPrintBeforePass,
366 std::function<
bool(
Pass *,
Operation *)> shouldPrintAfterPass,
367 bool printModuleScope,
bool printAfterOnlyOnChange,
371 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
372 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
373 opPrintingFlags, out));
378 std::function<
bool(
Pass *,
Operation *)> shouldPrintBeforePass,
379 std::function<
bool(
Pass *,
Operation *)> shouldPrintAfterPass,
380 bool printModuleScope,
bool printAfterOnlyOnChange,
381 bool printAfterOnlyOnFailure, StringRef printTreeDir,
384 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
385 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
386 opPrintingFlags, printTreeDir));
static std::pair< SmallVector< std::pair< std::string, std::string > >, std::string > getOpAndSymbolNames(Operation *op, StringRef passName, llvm::DenseMap< Operation *, unsigned > &counters)
Return pairs of (sanitized op name, symbol name) for op and all parent operations.
static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, OpPrintingFlags flags)
static std::unique_ptr< llvm::ToolOutputFile > createTreePrinterOutputPath(Operation *op, llvm::StringRef passArgument, llvm::StringRef rootDir, llvm::DenseMap< Operation *, unsigned > &counters)
Creates directories (if required) and opens an output file for the FileTreeIRPrinterConfig.
static LogicalResult createDirectoryOrPrintErr(llvm::StringRef dirPath)
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
PassInstrumentation provides several entry points into the pass manager infrastructure.
A configuration struct provided to the IR printer instrumentation.
virtual ~IRPrinterConfig()
IRPrinterConfig(bool printModuleScope=false, bool printAfterOnlyOnChange=false, bool printAfterOnlyOnFailure=false, OpPrintingFlags opPrintingFlags=OpPrintingFlags())
Initialize the configuration.
function_ref< void(raw_ostream &)> PrintCallbackFn
virtual void printBeforeIfEnabled(Pass *pass, Operation *operation, PrintCallbackFn printCallback)
A hook that may be overridden by a derived config that checks if the IR of 'operation' should be dump...
virtual void printAfterIfEnabled(Pass *pass, Operation *operation, PrintCallbackFn printCallback)
A hook that may be overridden by a derived config that checks if the IR of 'operation' should be dump...
void enableIRPrinting(std::unique_ptr< IRPrinterConfig > config)
Add an instrumentation to print the IR before and after pass execution, using the provided configurat...
MLIRContext * getContext() const
Return an instance of the context.
void enableIRPrintingToFileTree(std::function< bool(Pass *, Operation *)> shouldPrintBeforePass=[](Pass *, Operation *) { return true;}, std::function< bool(Pass *, Operation *)> shouldPrintAfterPass=[](Pass *, Operation *) { return true;}, bool printModuleScope=true, bool printAfterOnlyOnChange=true, bool printAfterOnlyOnFailure=false, llvm::StringRef printTreeDir=".pass_manager_output", OpPrintingFlags opPrintingFlags=OpPrintingFlags())
Similar to enableIRPrinting above, except that instead of printing the IR to a single output stream,...
void addInstrumentation(std::unique_ptr< PassInstrumentation > pi)
Add the provided instrumentation to the pass manager.
The abstract base pass class.
virtual StringRef getName() const =0
Returns the derived pass name.
virtual StringRef getArgument() const
Return the command line argument used when registering this pass.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Include the generated interface declarations.
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap