MLIR  16.0.0git
PassCrashRecovery.cpp
Go to the documentation of this file.
1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Verifier.h"
14 #include "mlir/Parser/Parser.h"
15 #include "mlir/Pass/Pass.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/CrashRecoveryContext.h"
22 #include "llvm/Support/Mutex.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/Threading.h"
25 #include "llvm/Support/ToolOutputFile.h"
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 //===----------------------------------------------------------------------===//
31 // RecoveryReproducerContext
32 //===----------------------------------------------------------------------===//
33 
34 namespace mlir {
35 namespace detail {
36 /// This class contains all of the context for generating a recovery reproducer.
37 /// Each recovery context is registered globally to allow for generating
38 /// reproducers when a signal is raised, such as a segfault.
40  RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
42  bool verifyPasses);
44 
45  /// Generate a reproducer with the current context.
46  void generate(std::string &description);
47 
48  /// Disable this reproducer context. This prevents the context from generating
49  /// a reproducer in the event of a crash.
50  void disable();
51 
52  /// Enable a previously disabled reproducer context.
53  void enable();
54 
55 private:
56  /// This function is invoked in the event of a crash.
57  static void crashHandler(void *);
58 
59  /// Register a signal handler to run in the event of a crash.
60  static void registerSignalHandler();
61 
62  /// The textual description of the currently executing pipeline.
63  std::string pipelineElements;
64 
65  /// The MLIR operation representing the IR before the crash.
66  Operation *preCrashOperation;
67 
68  /// The factory for the reproducer output stream to use when generating the
69  /// reproducer.
71 
72  /// Various pass manager and context flags.
73  bool disableThreads;
74  bool verifyPasses;
75 
76  /// The current set of active reproducer contexts. This is used in the event
77  /// of a crash. This is not thread_local as the pass manager may produce any
78  /// number of child threads. This uses a set to allow for multiple MLIR pass
79  /// managers to be running at the same time.
80  static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
81  static llvm::ManagedStatic<
82  llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
83  reproducerSet;
84 };
85 } // namespace detail
86 } // namespace mlir
87 
88 llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
89  RecoveryReproducerContext::reproducerMutex;
90 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
91  RecoveryReproducerContext::reproducerSet;
92 
94  std::string passPipelineStr, Operation *op,
95  PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
96  : pipelineElements(std::move(passPipelineStr)),
97  preCrashOperation(op->clone()), streamFactory(streamFactory),
98  disableThreads(!op->getContext()->isMultithreadingEnabled()),
99  verifyPasses(verifyPasses) {
100  enable();
101 }
102 
104  // Erase the cloned preCrash IR that we cached.
105  preCrashOperation->erase();
106  disable();
107 }
108 
109 void RecoveryReproducerContext::generate(std::string &description) {
110  llvm::raw_string_ostream descOS(description);
111 
112  // Try to create a new output stream for this crash reproducer.
113  std::string error;
114  std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
115  if (!stream) {
116  descOS << "failed to create output stream: " << error;
117  return;
118  }
119  descOS << "reproducer generated at `" << stream->description() << "`";
120 
121  std::string pipeline = (preCrashOperation->getName().getStringRef() + "(" +
122  pipelineElements + ")")
123  .str();
124  AsmState state(preCrashOperation);
125  state.attachResourcePrinter(
126  "mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
127  builder.buildString("pipeline", pipeline);
128  builder.buildBool("disable_threading", disableThreads);
129  builder.buildBool("verify_each", verifyPasses);
130  });
131 
132  // Output the .mlir module.
133  preCrashOperation->print(stream->os(), state);
134 }
135 
137  llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
138  reproducerSet->remove(this);
139  if (reproducerSet->empty())
140  llvm::CrashRecoveryContext::Disable();
141 }
142 
144  llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
145  if (reproducerSet->empty())
146  llvm::CrashRecoveryContext::Enable();
147  registerSignalHandler();
148  reproducerSet->insert(this);
149 }
150 
151 void RecoveryReproducerContext::crashHandler(void *) {
152  // Walk the current stack of contexts and generate a reproducer for each one.
153  // We can't know for certain which one was the cause, so we need to generate
154  // a reproducer for all of them.
155  for (RecoveryReproducerContext *context : *reproducerSet) {
156  std::string description;
157  context->generate(description);
158 
159  // Emit an error using information only available within the context.
160  emitError(context->preCrashOperation->getLoc())
161  << "A failure has been detected while processing the MLIR module:"
162  << description;
163  }
164 }
165 
166 void RecoveryReproducerContext::registerSignalHandler() {
167  // Ensure that the handler is only registered once.
168  static bool registered =
169  (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
170  (void)registered;
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // PassCrashReproducerGenerator
175 //===----------------------------------------------------------------------===//
176 
179  bool localReproducer)
181 
182  /// The factory to use when generating a crash reproducer.
184 
185  /// Flag indicating if reproducer generation should be localized to the
186  /// failing pass.
187  bool localReproducer = false;
188 
189  /// A record of all of the currently active reproducer contexts.
191 
192  /// The set of all currently running passes. Note: This is not populated when
193  /// `localReproducer` is true, as each pass will get its own recovery context.
195 
196  /// Various pass manager flags that get emitted when generating a reproducer.
197  bool pmFlagVerifyPasses = false;
198 };
199 
201  PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
202  : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
204 
207  bool pmFlagVerifyPasses) {
208  assert((!impl->localReproducer ||
209  !op->getContext()->isMultithreadingEnabled()) &&
210  "expected multi-threading to be disabled when generating a local "
211  "reproducer");
212 
213  llvm::CrashRecoveryContext::Enable();
214  impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
215 
216  // If we aren't generating a local reproducer, prepare a reproducer for the
217  // given top-level operation.
218  if (!impl->localReproducer)
219  prepareReproducerFor(passes, op);
220 }
221 
222 static void
224  std::pair<Pass *, Operation *> passOpPair) {
225  os << "`" << passOpPair.first->getName() << "` on "
226  << "'" << passOpPair.second->getName() << "' operation";
227  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
228  os << ": @" << symbol.getName();
229 }
230 
232  LogicalResult executionResult) {
233  // Don't generate a reproducer if we have no active contexts.
234  if (impl->activeContexts.empty())
235  return;
236 
237  // If the pass manager execution succeeded, we don't generate any reproducers.
238  if (succeeded(executionResult))
239  return impl->activeContexts.clear();
240 
242  << "Failures have been detected while "
243  "processing an MLIR pass pipeline";
244 
245  // If we are generating a global reproducer, we include all of the running
246  // passes in the error message for the only active context.
247  if (!impl->localReproducer) {
248  assert(impl->activeContexts.size() == 1 && "expected one active context");
249 
250  // Generate the reproducer.
251  std::string description;
252  impl->activeContexts.front()->generate(description);
253 
254  // Emit an error to the user.
255  Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
256  llvm::interleaveComma(impl->runningPasses, note,
257  [&](const std::pair<Pass *, Operation *> &value) {
258  formatPassOpReproducerMessage(note, value);
259  });
260  note << "]: " << description;
261  return;
262  }
263 
264  // If we were generating a local reproducer, we generate a reproducer for the
265  // most recently executing pass using the matching entry from `runningPasses`
266  // to generate a localized diagnostic message.
267  assert(impl->activeContexts.size() == impl->runningPasses.size() &&
268  "expected running passes to match active contexts");
269 
270  // Generate the reproducer.
271  RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
272  std::string description;
273  reproducerContext.generate(description);
274 
275  // Emit an error to the user.
276  Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
277  formatPassOpReproducerMessage(note, impl->runningPasses.back());
278  note << ": " << description;
279 
280  impl->activeContexts.clear();
281 }
282 
284  Operation *op) {
285  // If not tracking local reproducers, we simply remember that this pass is
286  // running.
287  impl->runningPasses.insert(std::make_pair(pass, op));
288  if (!impl->localReproducer)
289  return;
290 
291  // Disable the current pass recovery context, if there is one. This may happen
292  // in the case of dynamic pass pipelines.
293  if (!impl->activeContexts.empty())
294  impl->activeContexts.back()->disable();
295 
296  // Collect all of the parent scopes of this operation.
298  while (Operation *parentOp = op->getParentOp()) {
299  scopes.push_back(op->getName());
300  op = parentOp;
301  }
302 
303  // Emit a pass pipeline string for the current pass running on the current
304  // operation type.
305  std::string passStr;
306  llvm::raw_string_ostream passOS(passStr);
307  for (OperationName scope : llvm::reverse(scopes))
308  passOS << scope << "(";
309  pass->printAsTextualPipeline(passOS);
310  for (unsigned i = 0, e = scopes.size(); i < e; ++i)
311  passOS << ")";
312 
313  impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
314  passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
315 }
318  std::string passStr;
319  llvm::raw_string_ostream passOS(passStr);
320  llvm::interleaveComma(
321  passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
322 
323  impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
324  passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
325 }
326 
328  Operation *op) {
329  // We only pop the active context if we are tracking local reproducers.
330  impl->runningPasses.remove(std::make_pair(pass, op));
331  if (impl->localReproducer) {
332  impl->activeContexts.pop_back();
333 
334  // Re-enable the previous pass recovery context, if there was one. This may
335  // happen in the case of dynamic pass pipelines.
336  if (!impl->activeContexts.empty())
337  impl->activeContexts.back()->enable();
338  }
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // CrashReproducerInstrumentation
343 //===----------------------------------------------------------------------===//
344 
345 namespace {
346 struct CrashReproducerInstrumentation : public PassInstrumentation {
347  CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
348  : generator(generator) {}
349  ~CrashReproducerInstrumentation() override = default;
350 
351  void runBeforePass(Pass *pass, Operation *op) override {
352  if (!isa<OpToOpPassAdaptor>(pass))
353  generator.prepareReproducerFor(pass, op);
354  }
355 
356  void runAfterPass(Pass *pass, Operation *op) override {
357  if (!isa<OpToOpPassAdaptor>(pass))
358  generator.removeLastReproducerFor(pass, op);
359  }
360 
361  void runAfterPassFailed(Pass *pass, Operation *op) override {
362  generator.finalize(op, /*executionResult=*/failure());
363  }
364 
365 private:
366  /// The generator used to create crash reproducers.
368 };
369 } // namespace
370 
371 //===----------------------------------------------------------------------===//
372 // FileReproducerStream
373 //===----------------------------------------------------------------------===//
374 
375 namespace {
376 /// This class represents a default instance of PassManager::ReproducerStream
377 /// that is backed by a file.
378 struct FileReproducerStream : public PassManager::ReproducerStream {
379  FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
380  : outputFile(std::move(outputFile)) {}
381  ~FileReproducerStream() override { outputFile->keep(); }
382 
383  /// Returns a description of the reproducer stream.
384  StringRef description() override { return outputFile->getFilename(); }
385 
386  /// Returns the stream on which to output the reproducer.
387  raw_ostream &os() override { return outputFile->os(); }
388 
389 private:
390  /// ToolOutputFile corresponding to opened `filename`.
391  std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
392 };
393 } // namespace
394 
395 //===----------------------------------------------------------------------===//
396 // PassManager
397 //===----------------------------------------------------------------------===//
398 
399 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
400  AnalysisManager am) {
401  crashReproGenerator->initialize(getPasses(), op, verifyPasses);
402 
403  // Safely invoke the passes within a recovery context.
404  LogicalResult passManagerResult = failure();
405  llvm::CrashRecoveryContext recoveryContext;
406  recoveryContext.RunSafelyOnThread(
407  [&] { passManagerResult = runPasses(op, am); });
408  crashReproGenerator->finalize(op, passManagerResult);
409  return passManagerResult;
410 }
411 
413  bool genLocalReproducer) {
414  // Capture the filename by value in case outputFile is out of scope when
415  // invoked.
416  std::string filename = outputFile.str();
417  enableCrashReproducerGeneration(
418  [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
419  std::unique_ptr<llvm::ToolOutputFile> outputFile =
420  mlir::openOutputFile(filename, &error);
421  if (!outputFile) {
422  error = "Failed to create reproducer stream: " + error;
423  return nullptr;
424  }
425  return std::make_unique<FileReproducerStream>(std::move(outputFile));
426  },
427  genLocalReproducer);
428 }
429 
431  ReproducerStreamFactory factory, bool genLocalReproducer) {
432  assert(!crashReproGenerator &&
433  "crash reproducer has already been initialized");
434  if (genLocalReproducer && getContext()->isMultithreadingEnabled())
435  llvm::report_fatal_error(
436  "Local crash reproduction can't be setup on a "
437  "pass-manager without disabling multi-threading first.");
438 
439  crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
440  factory, genLocalReproducer);
441  addInstrumentation(
442  std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // Asm Resource
447 //===----------------------------------------------------------------------===//
448 
450  auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult {
451  if (entry.getKey() == "pipeline") {
452  FailureOr<std::string> value = entry.parseAsString();
453  if (succeeded(value))
454  this->pipeline = std::move(*value);
455  return value;
456  }
457  if (entry.getKey() == "disable_threading") {
458  FailureOr<bool> value = entry.parseAsBool();
459  if (succeeded(value))
460  this->disableThreading = *value;
461  return value;
462  }
463  if (entry.getKey() == "verify_each") {
464  FailureOr<bool> value = entry.parseAsBool();
465  if (succeeded(value))
466  this->verifyEach = *value;
467  return value;
468  }
469  return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
470  << entry.getKey() << "'";
471  };
472  config.attachResourceParser("mlir_reproducer", parseFn);
473 }
474 
476  if (pipeline.has_value()) {
477  FailureOr<OpPassManager> reproPm = parsePassPipeline(*pipeline);
478  if (failed(reproPm))
479  return failure();
480  static_cast<OpPassManager &>(pm) = std::move(*reproPm);
481  }
482 
483  if (disableThreading.has_value())
484  pm.getContext()->disableMultithreading(*disableThreading);
485 
486  if (verifyEach.has_value())
487  pm.enableVerifier(*verifyEach);
488 
489  return success();
490 }
static std::string diag(llvm::Value &value)
static constexpr const bool value
static const mlir::GenInfo * generator
static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)
This class represents an analysis manager for a particular operation instance.
This class represents a single parsed resource entry.
Definition: AsmState.h:280
This class is used to build resource entries for use by the printer.
Definition: AsmState.h:236
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:524
void attachResourcePrinter(std::unique_ptr< AsmResourcePrinter > printer)
Attach the given resource printer to the AsmState.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Definition: PassManager.h:52
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:165
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:418
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:457
void attachResourceParser(std::unique_ptr< AsmResourceParser > parser)
Attach the given resource parser.
Definition: AsmState.h:488
PassInstrumentation provides several entry points into the pass manager infrastructure.
The main pass manager and pipeline builder.
Definition: PassManager.h:211
MLIRContext * getContext() const
Return an instance of the context.
Definition: PassManager.h:230
std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory
Method type for constructing ReproducerStream.
Definition: PassManager.h:253
void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)
Enable support for the pass manager to generate a reproducer on the event of a crash or a pass failure.
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
Definition: Pass.cpp:775
The abstract base pass class.
Definition: Pass.h:50
void printAsTextualPipeline(raw_ostream &os)
Prints out the pass in the textual representation of pipelines.
Definition: Pass.cpp:54
void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)
Initialize the generator in preparation for reproducer generation.
void removeLastReproducerFor(Pass *pass, Operation *op)
Remove the last recorded reproducer anchored at the given pass and operation.
void finalize(Operation *rootOp, LogicalResult executionResult)
Finalize the current run of the generator, generating any necessary reproducers if the provided execution result is a failure.
PassCrashReproducerGenerator(PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
void prepareReproducerFor(Pass *pass, Operation *op)
Prepare a new reproducer for the given pass, operating on op.
Detect if any of the given parameter types has a sub-element handler.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
bool pmFlagVerifyPasses
Various pass manager flags that get emitted when generating a reproducer.
Impl(PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
SetVector< std::pair< Pass *, Operation * > > runningPasses
The set of all currently running passes.
bool localReproducer
Flag indicating if reproducer generation should be localized to the failing pass.
PassManager::ReproducerStreamFactory streamFactory
The factory to use when generating a crash reproducer.
SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts
A record of all of the currently active reproducer contexts.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Streams on which to output crash reproducer.
Definition: PassManager.h:241
void attachResourceParser(ParserConfig &config)
Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into this instance.
LogicalResult apply(PassManager &pm) const
Apply the reproducer options to 'pm' and its context.
This class contains all of the context for generating a recovery reproducer.
void disable()
Disable this reproducer context.
void generate(std::string &description)
Generate a reproducer with the current context.
RecoveryReproducerContext(std::string passPipelineStr, Operation *op, PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
void enable()
Enable a previously disabled reproducer context.