16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/SetVector.h"
18#include "llvm/Support/CrashRecoveryContext.h"
19#include "llvm/Support/ManagedStatic.h"
20#include "llvm/Support/Mutex.h"
21#include "llvm/Support/Signals.h"
22#include "llvm/Support/Threading.h"
23#include "llvm/Support/ToolOutputFile.h"
44 void generate(std::string &description);
55 static void crashHandler(
void *);
58 static void registerSignalHandler();
61 std::string pipelineElements;
78 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
79 static llvm::ManagedStatic<
80 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
86llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
87 RecoveryReproducerContext::reproducerMutex;
88llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
89 RecoveryReproducerContext::reproducerSet;
92 std::string passPipelineStr,
Operation *op,
94 : pipelineElements(std::move(passPipelineStr)),
95 preCrashOperation(op->
clone()), streamFactory(streamFactory),
96 disableThreads(!op->
getContext()->isMultithreadingEnabled()),
97 verifyPasses(verifyPasses) {
103 preCrashOperation->erase();
109 const std::string &pipeline,
bool disableThreads,
111 llvm::raw_string_ostream descOS(description);
115 std::unique_ptr<ReproducerStream> stream = factory(error);
117 descOS <<
"failed to create output stream: " << error;
120 descOS <<
"reproducer generated at `" << stream->description() <<
"`";
126 builder.
buildBool(
"disable_threading", disableThreads);
127 builder.
buildBool(
"verify_each", verifyPasses);
131 op->
print(stream->os(), state);
135 std::string pipeline = (preCrashOperation->getName().getStringRef() +
"(" +
136 pipelineElements +
")")
139 disableThreads, verifyPasses);
143 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
144 reproducerSet->remove(
this);
145 if (reproducerSet->empty())
146 llvm::CrashRecoveryContext::Disable();
150 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
151 if (reproducerSet->empty())
152 llvm::CrashRecoveryContext::Enable();
153 registerSignalHandler();
154 reproducerSet->insert(
this);
157void RecoveryReproducerContext::crashHandler(
void *) {
162 std::string description;
163 context->generate(description);
166 emitError(context->preCrashOperation->getLoc())
167 <<
"A signal was caught while processing the MLIR module:"
168 << description <<
"; marking pass as failed";
172void RecoveryReproducerContext::registerSignalHandler() {
174 static bool registered =
175 (llvm::sys::AddSignalHandler(crashHandler,
nullptr),
false);
207 :
impl(std::make_unique<
Impl>(streamFactory, localReproducer)) {}
212 bool pmFlagVerifyPasses) {
213 assert((!
impl->localReproducer ||
215 "expected multi-threading to be disabled when generating a local "
218 llvm::CrashRecoveryContext::Enable();
219 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
223 if (!
impl->localReproducer)
229 std::pair<Pass *, Operation *> passOpPair) {
230 os <<
"`" << passOpPair.first->getName() <<
"` on "
231 <<
"'" << passOpPair.second->getName() <<
"' operation";
232 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
233 os <<
": @" << symbol.getName();
237 LogicalResult executionResult) {
239 if (
impl->activeContexts.empty())
243 if (succeeded(executionResult))
244 return impl->activeContexts.clear();
247 <<
"Failures have been detected while "
248 "processing an MLIR pass pipeline";
252 if (!
impl->localReproducer) {
253 assert(
impl->activeContexts.size() == 1 &&
"expected one active context");
256 std::string description;
257 impl->activeContexts.front()->generate(description);
260 Diagnostic ¬e =
diag.attachNote() <<
"Pipeline failed while executing [";
261 llvm::interleaveComma(
impl->runningPasses, note,
262 [&](
const std::pair<Pass *, Operation *> &value) {
263 formatPassOpReproducerMessage(note, value);
265 note <<
"]: " << description;
266 impl->runningPasses.clear();
267 impl->activeContexts.clear();
274 assert(
impl->activeContexts.size() ==
impl->runningPasses.size() &&
275 "expected running passes to match active contexts");
279 std::string description;
280 reproducerContext.
generate(description);
283 Diagnostic ¬e =
diag.attachNote() <<
"Pipeline failed while executing ";
285 note <<
": " << description;
287 impl->activeContexts.clear();
288 impl->runningPasses.clear();
295 impl->runningPasses.insert(std::make_pair(pass, op));
296 if (!
impl->localReproducer)
301 if (!
impl->activeContexts.empty())
302 impl->activeContexts.back()->disable();
307 scopes.push_back(op->
getName());
314 llvm::raw_string_ostream passOS(passStr);
316 passOS << scope <<
"(";
318 for (
unsigned i = 0, e = scopes.size(); i < e; ++i)
321 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
322 passStr, op,
impl->streamFactory,
impl->pmFlagVerifyPasses));
327 llvm::raw_string_ostream passOS(passStr);
328 llvm::interleaveComma(
331 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
332 passStr, op,
impl->streamFactory,
impl->pmFlagVerifyPasses));
338 impl->runningPasses.remove(std::make_pair(pass, op));
339 if (
impl->localReproducer) {
340 impl->activeContexts.pop_back();
344 if (!
impl->activeContexts.empty())
345 impl->activeContexts.back()->enable();
357 ~CrashReproducerInstrumentation()
override =
default;
360 if (!isa<OpToOpPassAdaptor>(pass))
361 generator.prepareReproducerFor(pass, op);
365 if (!isa<OpToOpPassAdaptor>(pass))
366 generator.removeLastReproducerFor(pass, op);
369 void runAfterPassFailed(Pass *pass, Operation *op)
override {
374 alreadyFailed =
true;
381 bool alreadyFailed =
false;
392struct FileReproducerStream :
public mlir::ReproducerStream {
393 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
394 : outputFile(std::move(outputFile)) {}
395 ~FileReproducerStream()
override { outputFile->keep(); }
398 StringRef description()
override {
return outputFile->getFilename(); }
401 raw_ostream &os()
override {
return outputFile->os(); }
405 std::unique_ptr<llvm::ToolOutputFile> outputFile =
nullptr;
413LogicalResult PassManager::runWithCrashRecovery(Operation *op,
414 AnalysisManager am) {
415 const bool threadingEnabled =
getContext()->isMultithreadingEnabled();
416 crashReproGenerator->initialize(getPasses(), op, verifyPasses);
419 LogicalResult passManagerResult = failure();
420 llvm::CrashRecoveryContext recoveryContext;
421 const auto runPassesFn = [&] { passManagerResult = runPasses(op, am); };
422 if (threadingEnabled)
423 recoveryContext.RunSafelyOnThread(runPassesFn);
425 recoveryContext.RunSafely(runPassesFn);
426 crashReproGenerator->finalize(op, passManagerResult);
428 return passManagerResult;
435 std::string filename = outputFile.str();
436 return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
437 std::unique_ptr<llvm::ToolOutputFile> outputFile =
440 error =
"Failed to create reproducer stream: " + error;
443 return std::make_unique<FileReproducerStream>(std::move(outputFile));
448 raw_ostream &os, StringRef anchorName,
449 const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
450 bool pretty =
false);
453 StringRef anchorName,
455 Operation *op, StringRef outputFile,
bool disableThreads,
458 std::string description;
459 std::string pipelineStr;
460 llvm::raw_string_ostream passOS(pipelineStr);
463 pipelineStr, disableThreads, verifyPasses);
468 bool genLocalReproducer) {
475 assert(!crashReproGenerator &&
476 "crash reproducer has already been initialized");
477 if (genLocalReproducer &&
getContext()->isMultithreadingEnabled())
478 llvm::report_fatal_error(
479 "Local crash reproduction can't be setup on a "
480 "pass-manager without disabling multi-threading first.");
482 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
483 factory, genLocalReproducer);
485 std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
494 if (entry.getKey() ==
"pipeline") {
495 FailureOr<std::string> value = entry.parseAsString();
496 if (succeeded(value))
497 this->pipeline = std::move(*value);
500 if (entry.getKey() ==
"disable_threading") {
501 FailureOr<bool> value = entry.parseAsBool();
502 if (succeeded(value))
503 this->disableThreading = *value;
506 if (entry.getKey() ==
"verify_each") {
507 FailureOr<bool> value = entry.parseAsBool();
508 if (succeeded(value))
509 this->verifyEach = *value;
512 return entry.emitError() <<
"unknown 'mlir_reproducer' resource key '"
513 << entry.getKey() <<
"'";
515 config.attachResourceParser(
"mlir_reproducer", parseFn);
519 if (pipeline.has_value()) {
526 if (disableThreading.has_value())
529 if (verifyEach.has_value())
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static void appendReproducer(std::string &description, Operation *op, const ReproducerStreamFactory &factory, const std::string &pipeline, bool disableThreads, bool verifyPasses)
static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)
static ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile)
void printAsTextualPipeline(raw_indented_ostream &os, StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, bool pretty=false)
Prints out the passes of the pass manager as the textual representation of pipelines.
This class represents a single parsed resource entry.
This class is used to build resource entries for use by the printer.
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
This class provides management for the lifetime of the state used when printing the IR.
void attachResourcePrinter(std::unique_ptr< AsmResourcePrinter > printer)
Attach the given resource printer to the AsmState.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents a configuration for the MLIR assembly parser.
PassInstrumentation provides several entry points into the pass manager infrastructure.
The main pass manager and pipeline builder.
MLIRContext * getContext() const
Return an instance of the context.
void addInstrumentation(std::unique_ptr< PassInstrumentation > pi)
Add the provided instrumentation to the pass manager.
void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)
Enable support for the pass manager to generate a reproducer on the event of a crash or a pass failure.
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
The abstract base pass class.
void printAsTextualPipeline(raw_ostream &os, bool pretty=false)
Prints out the pass in the textual representation of pipelines.
void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)
Initialize the generator in preparation for reproducer generation.
void removeLastReproducerFor(Pass *pass, Operation *op)
Remove the last recorded reproducer anchored at the given pass and operation.
void finalize(Operation *rootOp, LogicalResult executionResult)
Finalize the current run of the generator, generating any necessary reproducers if the provided execution result is a failure.
void prepareReproducerFor(Pass *pass, Operation *op)
Prepare a new reproducer for the given pass, operating on op.
~PassCrashReproducerGenerator()
PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory, bool localReproducer)
Include the generated interface declarations.
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
const FrozenRewritePatternSet GreedyRewriteConfig config
std::string makeReproducer(StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, Operation *op, StringRef outputFile, bool disableThreads=false, bool verifyPasses=false)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory
Method type for constructing ReproducerStream.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool pmFlagVerifyPasses
Various pass manager flags that get emitted when generating a reproducer.
ReproducerStreamFactory streamFactory
The factory to use when generating a crash reproducer.
SetVector< std::pair< Pass *, Operation * > > runningPasses
The set of all currently running passes.
bool localReproducer
Flag indicating if reproducer generation should be localized to the failing pass.
Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts
A record of all of the currently active reproducer contexts.
void attachResourceParser(ParserConfig &config)
Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into this instance.
LogicalResult apply(PassManager &pm) const
Apply the reproducer options to 'pm' and its context.
This class contains all of the context for generating a recovery reproducer.
void disable()
Disable this reproducer context.
~RecoveryReproducerContext()
RecoveryReproducerContext(std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses)
void generate(std::string &description)
Generate a reproducer with the current context.
void enable()
Enable a previously disabled reproducer context.