16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/SetVector.h"
18#include "llvm/Support/CrashRecoveryContext.h"
19#include "llvm/Support/ManagedStatic.h"
20#include "llvm/Support/Mutex.h"
21#include "llvm/Support/Signals.h"
22#include "llvm/Support/Threading.h"
23#include "llvm/Support/ToolOutputFile.h"
44 void generate(std::string &description);
55 static void crashHandler(
void *);
58 static void registerSignalHandler();
61 std::string pipelineElements;
78 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
79 static llvm::ManagedStatic<
80 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
86llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
87 RecoveryReproducerContext::reproducerMutex;
88llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
89 RecoveryReproducerContext::reproducerSet;
92 std::string passPipelineStr,
Operation *op,
94 : pipelineElements(std::move(passPipelineStr)),
95 preCrashOperation(op->
clone()), streamFactory(streamFactory),
96 disableThreads(!op->
getContext()->isMultithreadingEnabled()),
97 verifyPasses(verifyPasses) {
103 preCrashOperation->erase();
109 const std::string &pipelineElements,
110 bool disableThreads,
bool verifyPasses) {
111 llvm::raw_string_ostream descOS(description);
115 std::unique_ptr<ReproducerStream> stream = factory(error);
117 descOS <<
"failed to create output stream: " << error;
120 descOS <<
"reproducer generated at `" << stream->description() <<
"`";
122 std::string pipeline =
123 (op->
getName().getStringRef() +
"(" + pipelineElements +
")").str();
128 builder.
buildBool(
"disable_threading", disableThreads);
129 builder.
buildBool(
"verify_each", verifyPasses);
133 op->
print(stream->os(), state);
138 pipelineElements, disableThreads, verifyPasses);
142 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
143 reproducerSet->remove(
this);
144 if (reproducerSet->empty())
145 llvm::CrashRecoveryContext::Disable();
149 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
150 if (reproducerSet->empty())
151 llvm::CrashRecoveryContext::Enable();
152 registerSignalHandler();
153 reproducerSet->insert(
this);
156void RecoveryReproducerContext::crashHandler(
void *) {
161 std::string description;
162 context->generate(description);
165 emitError(context->preCrashOperation->getLoc())
166 <<
"A signal was caught while processing the MLIR module:"
167 << description <<
"; marking pass as failed";
171void RecoveryReproducerContext::registerSignalHandler() {
173 static bool registered =
174 (llvm::sys::AddSignalHandler(crashHandler,
nullptr),
false);
206 :
impl(std::make_unique<
Impl>(streamFactory, localReproducer)) {}
211 bool pmFlagVerifyPasses) {
212 assert((!
impl->localReproducer ||
214 "expected multi-threading to be disabled when generating a local "
217 llvm::CrashRecoveryContext::Enable();
218 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
222 if (!
impl->localReproducer)
228 std::pair<Pass *, Operation *> passOpPair) {
229 os <<
"`" << passOpPair.first->getName() <<
"` on "
230 <<
"'" << passOpPair.second->getName() <<
"' operation";
231 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
232 os <<
": @" << symbol.getName();
236 LogicalResult executionResult) {
238 if (
impl->activeContexts.empty())
242 if (succeeded(executionResult))
243 return impl->activeContexts.clear();
246 <<
"Failures have been detected while "
247 "processing an MLIR pass pipeline";
251 if (!
impl->localReproducer) {
252 assert(
impl->activeContexts.size() == 1 &&
"expected one active context");
255 std::string description;
256 impl->activeContexts.front()->generate(description);
259 Diagnostic ¬e =
diag.attachNote() <<
"Pipeline failed while executing [";
260 llvm::interleaveComma(
impl->runningPasses, note,
261 [&](
const std::pair<Pass *, Operation *> &value) {
262 formatPassOpReproducerMessage(note, value);
264 note <<
"]: " << description;
265 impl->runningPasses.clear();
266 impl->activeContexts.clear();
273 assert(
impl->activeContexts.size() ==
impl->runningPasses.size() &&
274 "expected running passes to match active contexts");
278 std::string description;
279 reproducerContext.
generate(description);
282 Diagnostic ¬e =
diag.attachNote() <<
"Pipeline failed while executing ";
284 note <<
": " << description;
286 impl->activeContexts.clear();
287 impl->runningPasses.clear();
294 impl->runningPasses.insert(std::make_pair(pass, op));
295 if (!
impl->localReproducer)
300 if (!
impl->activeContexts.empty())
301 impl->activeContexts.back()->disable();
306 scopes.push_back(op->
getName());
313 llvm::raw_string_ostream passOS(passStr);
315 passOS << scope <<
"(";
317 for (
unsigned i = 0, e = scopes.size(); i < e; ++i)
320 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
321 passStr, op,
impl->streamFactory,
impl->pmFlagVerifyPasses));
326 llvm::raw_string_ostream passOS(passStr);
327 llvm::interleaveComma(
330 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
331 passStr, op,
impl->streamFactory,
impl->pmFlagVerifyPasses));
337 impl->runningPasses.remove(std::make_pair(pass, op));
338 if (
impl->localReproducer) {
339 impl->activeContexts.pop_back();
343 if (!
impl->activeContexts.empty())
344 impl->activeContexts.back()->enable();
356 ~CrashReproducerInstrumentation()
override =
default;
359 if (!isa<OpToOpPassAdaptor>(pass))
360 generator.prepareReproducerFor(pass, op);
364 if (!isa<OpToOpPassAdaptor>(pass))
365 generator.removeLastReproducerFor(pass, op);
368 void runAfterPassFailed(Pass *pass, Operation *op)
override {
373 alreadyFailed =
true;
380 bool alreadyFailed =
false;
391struct FileReproducerStream :
public mlir::ReproducerStream {
392 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
393 : outputFile(std::move(outputFile)) {}
394 ~FileReproducerStream()
override { outputFile->keep(); }
397 StringRef description()
override {
return outputFile->getFilename(); }
400 raw_ostream &os()
override {
return outputFile->os(); }
404 std::unique_ptr<llvm::ToolOutputFile> outputFile =
nullptr;
412LogicalResult PassManager::runWithCrashRecovery(Operation *op,
413 AnalysisManager am) {
414 const bool threadingEnabled =
getContext()->isMultithreadingEnabled();
415 crashReproGenerator->initialize(getPasses(), op, verifyPasses);
418 LogicalResult passManagerResult = failure();
419 llvm::CrashRecoveryContext recoveryContext;
420 const auto runPassesFn = [&] { passManagerResult = runPasses(op, am); };
421 if (threadingEnabled)
422 recoveryContext.RunSafelyOnThread(runPassesFn);
424 recoveryContext.RunSafely(runPassesFn);
425 crashReproGenerator->finalize(op, passManagerResult);
427 return passManagerResult;
434 std::string filename = outputFile.str();
435 return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
436 std::unique_ptr<llvm::ToolOutputFile> outputFile =
439 error =
"Failed to create reproducer stream: " + error;
442 return std::make_unique<FileReproducerStream>(std::move(outputFile));
447 raw_ostream &os, StringRef anchorName,
448 const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
449 bool pretty =
false);
452 StringRef anchorName,
454 Operation *op, StringRef outputFile,
bool disableThreads,
457 std::string description;
458 std::string pipelineStr;
459 llvm::raw_string_ostream passOS(pipelineStr);
462 pipelineStr, disableThreads, verifyPasses);
467 bool genLocalReproducer) {
474 assert(!crashReproGenerator &&
475 "crash reproducer has already been initialized");
476 if (genLocalReproducer &&
getContext()->isMultithreadingEnabled())
477 llvm::report_fatal_error(
478 "Local crash reproduction can't be setup on a "
479 "pass-manager without disabling multi-threading first.");
481 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
482 factory, genLocalReproducer);
484 std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
493 if (entry.getKey() ==
"pipeline") {
494 FailureOr<std::string> value = entry.parseAsString();
495 if (succeeded(value))
496 this->pipeline = std::move(*value);
499 if (entry.getKey() ==
"disable_threading") {
500 FailureOr<bool> value = entry.parseAsBool();
501 if (succeeded(value))
502 this->disableThreading = *value;
505 if (entry.getKey() ==
"verify_each") {
506 FailureOr<bool> value = entry.parseAsBool();
507 if (succeeded(value))
508 this->verifyEach = *value;
511 return entry.emitError() <<
"unknown 'mlir_reproducer' resource key '"
512 << entry.getKey() <<
"'";
514 config.attachResourceParser(
"mlir_reproducer", parseFn);
518 if (pipeline.has_value()) {
525 if (disableThreading.has_value())
528 if (verifyEach.has_value())
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static void appendReproducer(std::string &description, Operation *op, const ReproducerStreamFactory &factory, const std::string &pipelineElements, bool disableThreads, bool verifyPasses)
static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)
static ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile)
void printAsTextualPipeline(raw_indented_ostream &os, StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, bool pretty=false)
Prints out the passes of the pass manager as the textual representation of pipelines.
This class represents a single parsed resource entry.
This class is used to build resource entries for use by the printer.
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
This class provides management for the lifetime of the state used when printing the IR.
void attachResourcePrinter(std::unique_ptr< AsmResourcePrinter > printer)
Attach the given resource printer to the AsmState.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a pass manager that runs passes on either a specific operation type,...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents a configuration for the MLIR assembly parser.
PassInstrumentation provides several entry points into the pass manager infrastructure.
The main pass manager and pipeline builder.
MLIRContext * getContext() const
Return an instance of the context.
void addInstrumentation(std::unique_ptr< PassInstrumentation > pi)
Add the provided instrumentation to the pass manager.
void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)
Enable support for the pass manager to generate a reproducer on the event of a crash or a pass failur...
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
The abstract base pass class.
void printAsTextualPipeline(raw_ostream &os, bool pretty=false)
Prints out the pass in the textual representation of pipelines.
void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)
Initialize the generator in preparation for reproducer generation.
void removeLastReproducerFor(Pass *pass, Operation *op)
Remove the last recorded reproducer anchored at the given pass and operation.
void finalize(Operation *rootOp, LogicalResult executionResult)
Finalize the current run of the generator, generating any necessary reproducers if the provided execu...
void prepareReproducerFor(Pass *pass, Operation *op)
Prepare a new reproducer for the given pass, operating on op.
~PassCrashReproducerGenerator()
PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory, bool localReproducer)
Include the generated interface declarations.
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
const FrozenRewritePatternSet GreedyRewriteConfig config
std::string makeReproducer(StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, Operation *op, StringRef outputFile, bool disableThreads=false, bool verifyPasses=false)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory
Method type for constructing ReproducerStream.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool pmFlagVerifyPasses
Various pass manager flags that get emitted when generating a reproducer.
ReproducerStreamFactory streamFactory
The factory to use when generating a crash reproducer.
SetVector< std::pair< Pass *, Operation * > > runningPasses
The set of all currently running passes.
bool localReproducer
Flag indicating if reproducer generation should be localized to the failing pass.
Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts
A record of all of the currently active reproducer contexts.
void attachResourceParser(ParserConfig &config)
Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into t...
LogicalResult apply(PassManager &pm) const
Apply the reproducer options to 'pm' and its context.
This class contains all of the context for generating a recovery reproducer.
void disable()
Disable this reproducer context.
~RecoveryReproducerContext()
RecoveryReproducerContext(std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses)
void generate(std::string &description)
Generate a reproducer with the current context.
void enable()
Enable a previously disabled reproducer context.