16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/Support/CrashRecoveryContext.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Mutex.h"
21 #include "llvm/Support/Signals.h"
22 #include "llvm/Support/Threading.h"
23 #include "llvm/Support/ToolOutputFile.h"
44 void generate(std::string &description);
55 static void crashHandler(
void *);
58 static void registerSignalHandler();
61 std::string pipelineElements;
78 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
79 static llvm::ManagedStatic<
80 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
86 llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
87 RecoveryReproducerContext::reproducerMutex;
88 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
89 RecoveryReproducerContext::reproducerSet;
92 std::string passPipelineStr,
Operation *op,
94 : pipelineElements(std::move(passPipelineStr)),
95 preCrashOperation(op->
clone()), streamFactory(streamFactory),
96 disableThreads(!op->
getContext()->isMultithreadingEnabled()),
97 verifyPasses(verifyPasses) {
103 preCrashOperation->
erase();
109 const std::string &pipelineElements,
110 bool disableThreads,
bool verifyPasses) {
111 llvm::raw_string_ostream descOS(description);
115 std::unique_ptr<ReproducerStream> stream = factory(error);
117 descOS <<
"failed to create output stream: " << error;
120 descOS <<
"reproducer generated at `" << stream->description() <<
"`";
122 std::string pipeline =
125 state.attachResourcePrinter(
128 builder.
buildBool(
"disable_threading", disableThreads);
129 builder.
buildBool(
"verify_each", verifyPasses);
133 op->
print(stream->os(), state);
138 pipelineElements, disableThreads, verifyPasses);
142 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
143 reproducerSet->remove(
this);
144 if (reproducerSet->empty())
145 llvm::CrashRecoveryContext::Disable();
149 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
150 if (reproducerSet->empty())
151 llvm::CrashRecoveryContext::Enable();
152 registerSignalHandler();
153 reproducerSet->insert(
this);
156 void RecoveryReproducerContext::crashHandler(
void *) {
161 std::string description;
162 context->generate(description);
165 emitError(context->preCrashOperation->getLoc())
166 <<
"A signal was caught while processing the MLIR module:"
167 << description <<
"; marking pass as failed";
171 void RecoveryReproducerContext::registerSignalHandler() {
173 static bool registered =
174 (llvm::sys::AddSignalHandler(crashHandler,
nullptr),
false);
206 :
impl(std::make_unique<
Impl>(streamFactory, localReproducer)) {}
211 bool pmFlagVerifyPasses) {
212 assert((!
impl->localReproducer ||
214 "expected multi-threading to be disabled when generating a local "
217 llvm::CrashRecoveryContext::Enable();
218 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
222 if (!
impl->localReproducer)
228 std::pair<Pass *, Operation *> passOpPair) {
229 os <<
"`" << passOpPair.first->getName() <<
"` on "
230 <<
"'" << passOpPair.second->getName() <<
"' operation";
231 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
232 os <<
": @" << symbol.getName();
236 LogicalResult executionResult) {
238 if (
impl->activeContexts.empty())
242 if (succeeded(executionResult))
243 return impl->activeContexts.clear();
246 <<
"Failures have been detected while "
247 "processing an MLIR pass pipeline";
251 if (!
impl->localReproducer) {
252 assert(
impl->activeContexts.size() == 1 &&
"expected one active context");
255 std::string description;
256 impl->activeContexts.front()->generate(description);
259 Diagnostic &note =
diag.attachNote() <<
"Pipeline failed while executing [";
260 llvm::interleaveComma(
impl->runningPasses, note,
261 [&](
const std::pair<Pass *, Operation *> &value) {
262 formatPassOpReproducerMessage(note, value);
264 note <<
"]: " << description;
265 impl->runningPasses.clear();
266 impl->activeContexts.clear();
273 assert(
impl->activeContexts.size() ==
impl->runningPasses.size() &&
274 "expected running passes to match active contexts");
278 std::string description;
279 reproducerContext.
generate(description);
282 Diagnostic &note =
diag.attachNote() <<
"Pipeline failed while executing ";
284 note <<
": " << description;
286 impl->activeContexts.clear();
287 impl->runningPasses.clear();
294 impl->runningPasses.insert(std::make_pair(pass, op));
295 if (!
impl->localReproducer)
300 if (!
impl->activeContexts.empty())
301 impl->activeContexts.back()->disable();
306 scopes.push_back(op->
getName());
313 llvm::raw_string_ostream passOS(passStr);
315 passOS << scope <<
"(";
317 for (
unsigned i = 0, e = scopes.size(); i < e; ++i)
320 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
321 passStr, op,
impl->streamFactory,
impl->pmFlagVerifyPasses));
326 llvm::raw_string_ostream passOS(passStr);
327 llvm::interleaveComma(
330 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
331 passStr, op,
impl->streamFactory,
impl->pmFlagVerifyPasses));
337 impl->runningPasses.remove(std::make_pair(pass, op));
338 if (
impl->localReproducer) {
339 impl->activeContexts.pop_back();
343 if (!
impl->activeContexts.empty())
344 impl->activeContexts.back()->enable();
356 ~CrashReproducerInstrumentation()
override =
default;
359 if (!isa<OpToOpPassAdaptor>(pass))
360 generator.prepareReproducerFor(pass, op);
364 if (!isa<OpToOpPassAdaptor>(pass))
365 generator.removeLastReproducerFor(pass, op);
368 void runAfterPassFailed(
Pass *pass,
Operation *op)
override {
373 alreadyFailed =
true;
380 bool alreadyFailed =
false;
392 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
393 : outputFile(std::move(outputFile)) {}
394 ~FileReproducerStream()
override { outputFile->keep(); }
397 StringRef description()
override {
return outputFile->getFilename(); }
400 raw_ostream &os()
override {
return outputFile->os(); }
404 std::unique_ptr<llvm::ToolOutputFile> outputFile =
nullptr;
412 LogicalResult PassManager::runWithCrashRecovery(
Operation *op,
415 crashReproGenerator->initialize(getPasses(), op, verifyPasses);
418 LogicalResult passManagerResult = failure();
419 llvm::CrashRecoveryContext recoveryContext;
420 const auto runPassesFn = [&] { passManagerResult = runPasses(op, am); };
421 if (threadingEnabled)
422 recoveryContext.RunSafelyOnThread(runPassesFn);
424 recoveryContext.RunSafely(runPassesFn);
425 crashReproGenerator->finalize(op, passManagerResult);
427 return passManagerResult;
434 std::string filename = outputFile.str();
435 return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
436 std::unique_ptr<llvm::ToolOutputFile> outputFile =
439 error =
"Failed to create reproducer stream: " + error;
442 return std::make_unique<FileReproducerStream>(std::move(outputFile));
447 raw_ostream &os, StringRef anchorName,
449 bool pretty =
false);
452 StringRef anchorName,
454 Operation *op, StringRef outputFile,
bool disableThreads,
457 std::string description;
458 std::string pipelineStr;
459 llvm::raw_string_ostream passOS(pipelineStr);
462 pipelineStr, disableThreads, verifyPasses);
467 bool genLocalReproducer) {
474 assert(!crashReproGenerator &&
475 "crash reproducer has already been initialized");
476 if (genLocalReproducer &&
getContext()->isMultithreadingEnabled())
477 llvm::report_fatal_error(
478 "Local crash reproduction can't be setup on a "
479 "pass-manager without disabling multi-threading first.");
481 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
482 factory, genLocalReproducer);
484 std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
493 if (entry.getKey() ==
"pipeline") {
494 FailureOr<std::string> value = entry.parseAsString();
495 if (succeeded(value))
496 this->pipeline = std::move(*value);
499 if (entry.getKey() ==
"disable_threading") {
500 FailureOr<bool> value = entry.parseAsBool();
501 if (succeeded(value))
502 this->disableThreading = *value;
505 if (entry.getKey() ==
"verify_each") {
506 FailureOr<bool> value = entry.parseAsBool();
507 if (succeeded(value))
508 this->verifyEach = *value;
511 return entry.emitError() <<
"unknown 'mlir_reproducer' resource key '"
512 << entry.getKey() <<
"'";
514 config.attachResourceParser(
"mlir_reproducer", parseFn);
518 if (pipeline.has_value()) {
525 if (disableThreading.has_value())
528 if (verifyEach.has_value())
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static void appendReproducer(std::string &description, Operation *op, const ReproducerStreamFactory &factory, const std::string &pipelineElements, bool disableThreads, bool verifyPasses)
static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)
void printAsTextualPipeline(raw_ostream &os, StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, bool pretty=false)
static ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile)
This class represents an analysis manager for a particular operation instance.
This class represents a single parsed resource entry.
This class is used to build resource entries for use by the printer.
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
This class provides management for the lifetime of the state used when printing the IR.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
void erase()
Remove this operation from its parent block and delete it.
This class represents a configuration for the MLIR assembly parser.
PassInstrumentation provides several entry points into the pass manager infrastructure.
The main pass manager and pipeline builder.
MLIRContext * getContext() const
Return an instance of the context.
void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)
Enable support for the pass manager to generate a reproducer on the event of a crash or a pass failure.
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
The abstract base pass class.
void printAsTextualPipeline(raw_ostream &os, bool pretty=false)
Prints out the pass in the textual representation of pipelines.
void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)
Initialize the generator in preparation for reproducer generation.
void removeLastReproducerFor(Pass *pass, Operation *op)
Remove the last recorded reproducer anchored at the given pass and operation.
void finalize(Operation *rootOp, LogicalResult executionResult)
Finalize the current run of the generator, generating any necessary reproducers if the provided execution result is a failure.
void prepareReproducerFor(Pass *pass, Operation *op)
Prepare a new reproducer for the given pass, operating on op.
~PassCrashReproducerGenerator()
PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory, bool localReproducer)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
std::string makeReproducer(StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, Operation *op, StringRef outputFile, bool disableThreads=false, bool verifyPasses=false)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory
Method type for constructing ReproducerStream.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool pmFlagVerifyPasses
Various pass manager flags that get emitted when generating a reproducer.
ReproducerStreamFactory streamFactory
The factory to use when generating a crash reproducer.
SetVector< std::pair< Pass *, Operation * > > runningPasses
The set of all currently running passes.
bool localReproducer
Flag indicating if reproducer generation should be localized to the failing pass.
Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts
A record of all of the currently active reproducer contexts.
void attachResourceParser(ParserConfig &config)
Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into this instance.
LogicalResult apply(PassManager &pm) const
Apply the reproducer options to 'pm' and its context.
Streams on which to output crash reproducer.
This class contains all of the context for generating a recovery reproducer.
void disable()
Disable this reproducer context.
~RecoveryReproducerContext()
RecoveryReproducerContext(std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses)
void generate(std::string &description)
Generate a reproducer with the current context.
void enable()
Enable a previously disabled reproducer context.