MLIR 22.0.0git
PassCrashRecovery.cpp
Go to the documentation of this file.
1//===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Diagnostics.h"
11#include "mlir/IR/SymbolTable.h"
12#include "mlir/IR/Verifier.h"
13#include "mlir/Parser/Parser.h"
14#include "mlir/Pass/Pass.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/SetVector.h"
18#include "llvm/Support/CrashRecoveryContext.h"
19#include "llvm/Support/ManagedStatic.h"
20#include "llvm/Support/Mutex.h"
21#include "llvm/Support/Signals.h"
22#include "llvm/Support/Threading.h"
23#include "llvm/Support/ToolOutputFile.h"
24
25using namespace mlir;
26using namespace mlir::detail;
27
28//===----------------------------------------------------------------------===//
29// RecoveryReproducerContext
30//===----------------------------------------------------------------------===//
31
32namespace mlir {
33namespace detail {
34/// This class contains all of the context for generating a recovery reproducer.
35/// Each recovery context is registered globally to allow for generating
36/// reproducers when a signal is raised, such as a segfault.
38 RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
39 ReproducerStreamFactory &streamFactory,
40 bool verifyPasses);
42
43 /// Generate a reproducer with the current context.
44 void generate(std::string &description);
45
46 /// Disable this reproducer context. This prevents the context from generating
47 /// a reproducer in the result of a crash.
48 void disable();
49
50 /// Enable a previously disabled reproducer context.
51 void enable();
52
53private:
54 /// This function is invoked in the event of a crash.
55 static void crashHandler(void *);
56
57 /// Register a signal handler to run in the event of a crash.
58 static void registerSignalHandler();
59
60 /// The textual description of the currently executing pipeline.
61 std::string pipelineElements;
62
63 /// The MLIR operation representing the IR before the crash.
64 Operation *preCrashOperation;
65
66 /// The factory for the reproducer output stream to use when generating the
67 /// reproducer.
68 ReproducerStreamFactory &streamFactory;
69
70 /// Various pass manager and context flags.
71 bool disableThreads;
72 bool verifyPasses;
73
74 /// The current set of active reproducer contexts. This is used in the event
75 /// of a crash. This is not thread_local as the pass manager may produce any
76 /// number of child threads. This uses a set to allow for multiple MLIR pass
77 /// managers to be running at the same time.
78 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
79 static llvm::ManagedStatic<
80 llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
81 reproducerSet;
82};
83} // namespace detail
84} // namespace mlir
85
86llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
87 RecoveryReproducerContext::reproducerMutex;
88llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
89 RecoveryReproducerContext::reproducerSet;
90
92 std::string passPipelineStr, Operation *op,
93 ReproducerStreamFactory &streamFactory, bool verifyPasses)
94 : pipelineElements(std::move(passPipelineStr)),
95 preCrashOperation(op->clone()), streamFactory(streamFactory),
96 disableThreads(!op->getContext()->isMultithreadingEnabled()),
97 verifyPasses(verifyPasses) {
98 enable();
99}
100
102 // Erase the cloned preCrash IR that we cached.
103 preCrashOperation->erase();
104 disable();
105}
106
107static void appendReproducer(std::string &description, Operation *op,
108 const ReproducerStreamFactory &factory,
109 const std::string &pipelineElements,
110 bool disableThreads, bool verifyPasses) {
111 llvm::raw_string_ostream descOS(description);
112
113 // Try to create a new output stream for this crash reproducer.
114 std::string error;
115 std::unique_ptr<ReproducerStream> stream = factory(error);
116 if (!stream) {
117 descOS << "failed to create output stream: " << error;
118 return;
119 }
120 descOS << "reproducer generated at `" << stream->description() << "`";
121
122 std::string pipeline =
123 (op->getName().getStringRef() + "(" + pipelineElements + ")").str();
124 AsmState state(op);
126 "mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
127 builder.buildString("pipeline", pipeline);
128 builder.buildBool("disable_threading", disableThreads);
129 builder.buildBool("verify_each", verifyPasses);
130 });
131
132 // Output the .mlir module.
133 op->print(stream->os(), state);
134}
135
136void RecoveryReproducerContext::generate(std::string &description) {
137 appendReproducer(description, preCrashOperation, streamFactory,
138 pipelineElements, disableThreads, verifyPasses);
139}
140
142 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
143 reproducerSet->remove(this);
144 if (reproducerSet->empty())
145 llvm::CrashRecoveryContext::Disable();
146}
147
149 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
150 if (reproducerSet->empty())
151 llvm::CrashRecoveryContext::Enable();
152 registerSignalHandler();
153 reproducerSet->insert(this);
154}
155
156void RecoveryReproducerContext::crashHandler(void *) {
157 // Walk the current stack of contexts and generate a reproducer for each one.
158 // We can't know for certain which one was the cause, so we need to generate
159 // a reproducer for all of them.
160 for (RecoveryReproducerContext *context : *reproducerSet) {
161 std::string description;
162 context->generate(description);
163
164 // Emit an error using information only available within the context.
165 emitError(context->preCrashOperation->getLoc())
166 << "A signal was caught while processing the MLIR module:"
167 << description << "; marking pass as failed";
168 }
169}
170
171void RecoveryReproducerContext::registerSignalHandler() {
172 // Ensure that the handler is only registered once.
173 static bool registered =
174 (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
175 (void)registered;
176}
177
178//===----------------------------------------------------------------------===//
179// PassCrashReproducerGenerator
180//===----------------------------------------------------------------------===//
181
185
186 /// The factory to use when generating a crash reproducer.
188
189 /// Flag indicating if reproducer generation should be localized to the
190 /// failing pass.
191 bool localReproducer = false;
192
193 /// A record of all of the currently active reproducer contexts.
195
196 /// The set of all currently running passes. Note: This is not populated when
197 /// `localReproducer` is true, as each pass will get its own recovery context.
199
200 /// Various pass manager flags that get emitted when generating a reproducer.
201 bool pmFlagVerifyPasses = false;
202};
203
205 ReproducerStreamFactory &streamFactory, bool localReproducer)
206 : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
208
211 bool pmFlagVerifyPasses) {
212 assert((!impl->localReproducer ||
214 "expected multi-threading to be disabled when generating a local "
215 "reproducer");
216
217 llvm::CrashRecoveryContext::Enable();
218 impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
219
220 // If we aren't generating a local reproducer, prepare a reproducer for the
221 // given top-level operation.
222 if (!impl->localReproducer)
223 prepareReproducerFor(passes, op);
224}
225
226static void
228 std::pair<Pass *, Operation *> passOpPair) {
229 os << "`" << passOpPair.first->getName() << "` on "
230 << "'" << passOpPair.second->getName() << "' operation";
231 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
232 os << ": @" << symbol.getName();
233}
234
236 LogicalResult executionResult) {
237 // Don't generate a reproducer if we have no active contexts.
238 if (impl->activeContexts.empty())
239 return;
240
241 // If the pass manager execution succeeded, we don't generate any reproducers.
242 if (succeeded(executionResult))
243 return impl->activeContexts.clear();
244
246 << "Failures have been detected while "
247 "processing an MLIR pass pipeline";
248
249 // If we are generating a global reproducer, we include all of the running
250 // passes in the error message for the only active context.
251 if (!impl->localReproducer) {
252 assert(impl->activeContexts.size() == 1 && "expected one active context");
253
254 // Generate the reproducer.
255 std::string description;
256 impl->activeContexts.front()->generate(description);
257
258 // Emit an error to the user.
259 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
260 llvm::interleaveComma(impl->runningPasses, note,
261 [&](const std::pair<Pass *, Operation *> &value) {
262 formatPassOpReproducerMessage(note, value);
263 });
264 note << "]: " << description;
265 impl->runningPasses.clear();
266 impl->activeContexts.clear();
267 return;
268 }
269
270 // If we were generating a local reproducer, we generate a reproducer for the
271 // most recently executing pass using the matching entry from `runningPasses`
272 // to generate a localized diagnostic message.
273 assert(impl->activeContexts.size() == impl->runningPasses.size() &&
274 "expected running passes to match active contexts");
275
276 // Generate the reproducer.
277 RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
278 std::string description;
279 reproducerContext.generate(description);
280
281 // Emit an error to the user.
282 Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
283 formatPassOpReproducerMessage(note, impl->runningPasses.back());
284 note << ": " << description;
285
286 impl->activeContexts.clear();
287 impl->runningPasses.clear();
288}
289
291 Operation *op) {
292 // If not tracking local reproducers, we simply remember that this pass is
293 // running.
294 impl->runningPasses.insert(std::make_pair(pass, op));
295 if (!impl->localReproducer)
296 return;
297
298 // Disable the current pass recovery context, if there is one. This may happen
299 // in the case of dynamic pass pipelines.
300 if (!impl->activeContexts.empty())
301 impl->activeContexts.back()->disable();
302
303 // Collect all of the parent scopes of this operation.
305 while (Operation *parentOp = op->getParentOp()) {
306 scopes.push_back(op->getName());
307 op = parentOp;
308 }
309
310 // Emit a pass pipeline string for the current pass running on the current
311 // operation type.
312 std::string passStr;
313 llvm::raw_string_ostream passOS(passStr);
314 for (OperationName scope : llvm::reverse(scopes))
315 passOS << scope << "(";
316 pass->printAsTextualPipeline(passOS);
317 for (unsigned i = 0, e = scopes.size(); i < e; ++i)
318 passOS << ")";
319
320 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
321 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));
322}
325 std::string passStr;
326 llvm::raw_string_ostream passOS(passStr);
327 llvm::interleaveComma(
328 passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
329
330 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
331 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));
332}
333
335 Operation *op) {
336 // We only pop the active context if we are tracking local reproducers.
337 impl->runningPasses.remove(std::make_pair(pass, op));
338 if (impl->localReproducer) {
339 impl->activeContexts.pop_back();
340
341 // Re-enable the previous pass recovery context, if there was one. This may
342 // happen in the case of dynamic pass pipelines.
343 if (!impl->activeContexts.empty())
344 impl->activeContexts.back()->enable();
345 }
346}
347
348//===----------------------------------------------------------------------===//
349// CrashReproducerInstrumentation
350//===----------------------------------------------------------------------===//
351
352namespace {
353struct CrashReproducerInstrumentation : public PassInstrumentation {
354 CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
355 : generator(generator) {}
356 ~CrashReproducerInstrumentation() override = default;
357
358 void runBeforePass(Pass *pass, Operation *op) override {
359 if (!isa<OpToOpPassAdaptor>(pass))
360 generator.prepareReproducerFor(pass, op);
361 }
362
363 void runAfterPass(Pass *pass, Operation *op) override {
364 if (!isa<OpToOpPassAdaptor>(pass))
365 generator.removeLastReproducerFor(pass, op);
366 }
367
368 void runAfterPassFailed(Pass *pass, Operation *op) override {
369 // Only generate one reproducer per crash reproducer instrumentation.
370 if (alreadyFailed)
371 return;
372
373 alreadyFailed = true;
374 generator.finalize(op, /*executionResult=*/failure());
375 }
376
377private:
378 /// The generator used to create crash reproducers.
379 PassCrashReproducerGenerator &generator;
380 bool alreadyFailed = false;
381};
382} // namespace
383
384//===----------------------------------------------------------------------===//
385// FileReproducerStream
386//===----------------------------------------------------------------------===//
387
388namespace {
389/// This class represents a default instance of mlir::ReproducerStream
390/// that is backed by a file.
391struct FileReproducerStream : public mlir::ReproducerStream {
392 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
393 : outputFile(std::move(outputFile)) {}
394 ~FileReproducerStream() override { outputFile->keep(); }
395
396 /// Returns a description of the reproducer stream.
397 StringRef description() override { return outputFile->getFilename(); }
398
399 /// Returns the stream on which to output the reproducer.
400 raw_ostream &os() override { return outputFile->os(); }
401
402private:
403 /// ToolOutputFile corresponding to opened `filename`.
404 std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
405};
406} // namespace
407
408//===----------------------------------------------------------------------===//
409// PassManager
410//===----------------------------------------------------------------------===//
411
412LogicalResult PassManager::runWithCrashRecovery(Operation *op,
413 AnalysisManager am) {
414 const bool threadingEnabled = getContext()->isMultithreadingEnabled();
415 crashReproGenerator->initialize(getPasses(), op, verifyPasses);
416
417 // Safely invoke the passes within a recovery context.
418 LogicalResult passManagerResult = failure();
419 llvm::CrashRecoveryContext recoveryContext;
420 const auto runPassesFn = [&] { passManagerResult = runPasses(op, am); };
421 if (threadingEnabled)
422 recoveryContext.RunSafelyOnThread(runPassesFn);
423 else
424 recoveryContext.RunSafely(runPassesFn);
425 crashReproGenerator->finalize(op, passManagerResult);
426
427 return passManagerResult;
428}
429
431makeReproducerStreamFactory(StringRef outputFile) {
432 // Capture the filename by value in case outputFile is out of scope when
433 // invoked.
434 std::string filename = outputFile.str();
435 return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
436 std::unique_ptr<llvm::ToolOutputFile> outputFile =
437 mlir::openOutputFile(filename, &error);
438 if (!outputFile) {
439 error = "Failed to create reproducer stream: " + error;
440 return nullptr;
441 }
442 return std::make_unique<FileReproducerStream>(std::move(outputFile));
443 };
444}
445
447 raw_ostream &os, StringRef anchorName,
448 const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
449 bool pretty = false);
450
452 StringRef anchorName,
454 Operation *op, StringRef outputFile, bool disableThreads,
455 bool verifyPasses) {
456
457 std::string description;
458 std::string pipelineStr;
459 llvm::raw_string_ostream passOS(pipelineStr);
460 ::printAsTextualPipeline(passOS, anchorName, passes);
461 appendReproducer(description, op, makeReproducerStreamFactory(outputFile),
462 pipelineStr, disableThreads, verifyPasses);
463 return description;
464}
465
467 bool genLocalReproducer) {
469 genLocalReproducer);
470}
471
473 ReproducerStreamFactory factory, bool genLocalReproducer) {
474 assert(!crashReproGenerator &&
475 "crash reproducer has already been initialized");
476 if (genLocalReproducer && getContext()->isMultithreadingEnabled())
477 llvm::report_fatal_error(
478 "Local crash reproduction can't be setup on a "
479 "pass-manager without disabling multi-threading first.");
480
481 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
482 factory, genLocalReproducer);
484 std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
485}
486
487//===----------------------------------------------------------------------===//
488// Asm Resource
489//===----------------------------------------------------------------------===//
490
492 auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult {
493 if (entry.getKey() == "pipeline") {
494 FailureOr<std::string> value = entry.parseAsString();
495 if (succeeded(value))
496 this->pipeline = std::move(*value);
497 return value;
498 }
499 if (entry.getKey() == "disable_threading") {
500 FailureOr<bool> value = entry.parseAsBool();
501 if (succeeded(value))
502 this->disableThreading = *value;
503 return value;
504 }
505 if (entry.getKey() == "verify_each") {
506 FailureOr<bool> value = entry.parseAsBool();
507 if (succeeded(value))
508 this->verifyEach = *value;
509 return value;
510 }
511 return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
512 << entry.getKey() << "'";
513 };
514 config.attachResourceParser("mlir_reproducer", parseFn);
515}
516
518 if (pipeline.has_value()) {
519 FailureOr<OpPassManager> reproPm = parsePassPipeline(*pipeline);
520 if (failed(reproPm))
521 return failure();
522 static_cast<OpPassManager &>(pm) = std::move(*reproPm);
523 }
524
525 if (disableThreading.has_value())
526 pm.getContext()->disableMultithreading(*disableThreading);
527
528 if (verifyEach.has_value())
529 pm.enableVerifier(*verifyEach);
530
531 return success();
532}
return success()
b getContext())
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static void appendReproducer(std::string &description, Operation *op, const ReproducerStreamFactory &factory, const std::string &pipelineElements, bool disableThreads, bool verifyPasses)
static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)
static ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile)
void printAsTextualPipeline(raw_indented_ostream &os, StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, bool pretty=false)
Prints out the passes of the pass manager as the textual representation of pipelines.
Definition Pass.cpp:422
This class represents a single parsed resource entry.
Definition AsmState.h:291
This class is used to build resource entries for use by the printer.
Definition AsmState.h:247
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
This class provides management for the lifetime of the state used when printing the IR.
Definition AsmState.h:542
void attachResourcePrinter(std::unique_ptr< AsmResourcePrinter > printer)
Attach the given resource printer to the AsmState.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Definition PassManager.h:46
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents a configuration for the MLIR assembly parser.
Definition AsmState.h:469
PassInstrumentation provides several entry points into the pass manager infrastructure.
The main pass manager and pipeline builder.
MLIRContext * getContext() const
Return an instance of the context.
void addInstrumentation(std::unique_ptr< PassInstrumentation > pi)
Add the provided instrumentation to the pass manager.
Definition Pass.cpp:1108
void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)
Enable support for the pass manager to generate a reproducer in the event of a crash or a pass failure.
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
Definition Pass.cpp:1026
The abstract base pass class.
Definition Pass.h:51
void printAsTextualPipeline(raw_ostream &os, bool pretty=false)
Prints out the pass in the textual representation of pipelines.
Definition Pass.cpp:85
void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)
Initialize the generator in preparation for reproducer generation.
void removeLastReproducerFor(Pass *pass, Operation *op)
Remove the last recorded reproducer anchored at the given pass and operation.
void finalize(Operation *rootOp, LogicalResult executionResult)
Finalize the current run of the generator, generating any necessary reproducers if the provided execution result is a failure.
void prepareReproducerFor(Pass *pass, Operation *op)
Prepare a new reproducer for the given pass, operating on op.
PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory, bool localReproducer)
AttrTypeReplacer.
Include the generated interface declarations.
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
const FrozenRewritePatternSet GreedyRewriteConfig config
std::string makeReproducer(StringRef anchorName, const llvm::iterator_range< OpPassManager::pass_iterator > &passes, Operation *op, StringRef outputFile, bool disableThreads=false, bool verifyPasses=false)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory
Method type for constructing ReproducerStream.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool pmFlagVerifyPasses
Various pass manager flags that get emitted when generating a reproducer.
ReproducerStreamFactory streamFactory
The factory to use when generating a crash reproducer.
SetVector< std::pair< Pass *, Operation * > > runningPasses
The set of all currently running passes.
bool localReproducer
Flag indicating if reproducer generation should be localized to the failing pass.
Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts
A record of all of the currently active reproducer contexts.
void attachResourceParser(ParserConfig &config)
Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into this instance.
LogicalResult apply(PassManager &pm) const
Apply the reproducer options to 'pm' and its context.
This class contains all of the context for generating a recovery reproducer.
void disable()
Disable this reproducer context.
RecoveryReproducerContext(std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses)
void generate(std::string &description)
Generate a reproducer with the current context.
void enable()
Enable a previously disabled reproducer context.