MLIR  18.0.0git
PassCrashRecovery.cpp
Go to the documentation of this file.
1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Verifier.h"
14 #include "mlir/Parser/Parser.h"
15 #include "mlir/Pass/Pass.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/CrashRecoveryContext.h"
22 #include "llvm/Support/Mutex.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/Threading.h"
25 #include "llvm/Support/ToolOutputFile.h"
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 //===----------------------------------------------------------------------===//
31 // RecoveryReproducerContext
32 //===----------------------------------------------------------------------===//
33 
34 namespace mlir {
35 namespace detail {
36 /// This class contains all of the context for generating a recovery reproducer.
37 /// Each recovery context is registered globally to allow for generating
38 /// reproducers when a signal is raised, such as a segfault.
40  RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
42  bool verifyPasses);
44 
45  /// Generate a reproducer with the current context.
46  void generate(std::string &description);
47 
48  /// Disable this reproducer context. This prevents the context from generating
49  /// a reproducer in the event of a crash.
50  void disable();
51 
52  /// Enable a previously disabled reproducer context.
53  void enable();
54 
55 private:
56  /// This function is invoked in the event of a crash.
57  static void crashHandler(void *);
58 
59  /// Register a signal handler to run in the event of a crash.
60  static void registerSignalHandler();
61 
62  /// The textual description of the currently executing pipeline.
63  std::string pipelineElements;
64 
65  /// The MLIR operation representing the IR before the crash.
66  Operation *preCrashOperation;
67 
68  /// The factory for the reproducer output stream to use when generating the
69  /// reproducer.
71 
72  /// Various pass manager and context flags.
73  bool disableThreads;
74  bool verifyPasses;
75 
76  /// The current set of active reproducer contexts. This is used in the event
77  /// of a crash. This is not thread_local as the pass manager may produce any
78  /// number of child threads. This uses a set to allow for multiple MLIR pass
79  /// managers to be running at the same time.
80  static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
81  static llvm::ManagedStatic<
82  llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
83  reproducerSet;
84 };
85 } // namespace detail
86 } // namespace mlir
87 
88 llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
89  RecoveryReproducerContext::reproducerMutex;
90 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
91  RecoveryReproducerContext::reproducerSet;
92 
94  std::string passPipelineStr, Operation *op,
95  PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
96  : pipelineElements(std::move(passPipelineStr)),
97  preCrashOperation(op->clone()), streamFactory(streamFactory),
98  disableThreads(!op->getContext()->isMultithreadingEnabled()),
99  verifyPasses(verifyPasses) {
100  enable();
101 }
102 
104  // Erase the cloned preCrash IR that we cached.
105  preCrashOperation->erase();
106  disable();
107 }
108 
/// Generate a reproducer for the cached pre-crash IR.
///
/// A new output stream is requested from `streamFactory`. On success the
/// cached operation is printed to that stream together with an
/// `mlir_reproducer` resource holding the pipeline string and the recorded
/// pass manager flags, and the location of the reproducer is written to
/// `description`. If no stream could be created, the failure reason is written
/// to `description` instead and nothing is printed.
109 void RecoveryReproducerContext::generate(std::string &description) {
110  llvm::raw_string_ostream descOS(description);
111 
112  // Try to create a new output stream for this crash reproducer.
     // The factory reports the reason for a failure through `error`.
113  std::string error;
114  std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
115  if (!stream) {
116  descOS << "failed to create output stream: " << error;
117  return;
118  }
119  descOS << "reproducer generated at `" << stream->description() << "`";
120 
     // Anchor the pipeline on the name of the cached operation, producing a
     // string of the form `<op-name>(<pipeline elements>)`.
121  std::string pipeline = (preCrashOperation->getName().getStringRef() + "(" +
122  pipelineElements + ")")
123  .str();
124  AsmState state(preCrashOperation);
     // Emit the pipeline and the recorded flags under the `mlir_reproducer`
     // resource key. The callback captures locals by reference; `state` is a
     // local as well, so the callback is not used past the end of this function.
125  state.attachResourcePrinter(
126  "mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
127  builder.buildString("pipeline", pipeline);
128  builder.buildBool("disable_threading", disableThreads);
129  builder.buildBool("verify_each", verifyPasses);
130  });
131 
132  // Output the .mlir module.
133  preCrashOperation->print(stream->os(), state);
134 }
135 
137  llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
138  reproducerSet->remove(this);
139  if (reproducerSet->empty())
140  llvm::CrashRecoveryContext::Disable();
141 }
142 
144  llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
145  if (reproducerSet->empty())
146  llvm::CrashRecoveryContext::Enable();
147  registerSignalHandler();
148  reproducerSet->insert(this);
149 }
150 
/// Callback invoked in the event of a crash; it is the function handed to
/// llvm::sys::AddSignalHandler by registerSignalHandler. The `void *` cookie
/// is unused.
///
/// NOTE(review): `reproducerSet` is walked here without taking
/// `reproducerMutex`, unlike enable() and disable(). Presumably this is
/// deliberate because the code runs from a signal handler -- confirm.
151 void RecoveryReproducerContext::crashHandler(void *) {
152  // Walk the current stack of contexts and generate a reproducer for each one.
153  // We can't know for certain which one was the cause, so we need to generate
154  // a reproducer for all of them.
155  for (RecoveryReproducerContext *context : *reproducerSet) {
     // `description` receives either the reproducer location or the reason no
     // output stream could be created.
156  std::string description;
157  context->generate(description);
158 
159  // Emit an error using information only available within the context.
     // NOTE(review): `description` is streamed directly after the trailing colon
     // of the message text, with no space in between -- confirm this is intended.
160  emitError(context->preCrashOperation->getLoc())
161  << "A signal was caught while processing the MLIR module:"
162  << description << "; marking pass as failed";
163  }
164 }
165 
/// Install `crashHandler` as a signal handler via llvm::sys::AddSignalHandler.
///
/// The registration is the initializer of a function-local static, so it is
/// evaluated on the first call only; later calls do nothing.
166 void RecoveryReproducerContext::registerSignalHandler() {
167  // Ensure that the handler is only registered once.
     // The comma expression performs the registration and then yields `false`
     // as the value stored in the static.
168  static bool registered =
169  (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
     // The variable exists only for the side effect of its initializer; the cast
     // marks it as intentionally unused.
170  (void)registered;
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // PassCrashReproducerGenerator
175 //===----------------------------------------------------------------------===//
176 
179  bool localReproducer)
181 
182  /// The factory to use when generating a crash reproducer.
184 
185  /// Flag indicating if reproducer generation should be localized to the
186  /// failing pass.
187  bool localReproducer = false;
188 
189  /// A record of all of the currently active reproducer contexts.
191 
192  /// The set of all currently running passes. Note: This is not populated when
193  /// `localReproducer` is true, as each pass will get its own recovery context.
195 
196  /// Various pass manager flags that get emitted when generating a reproducer.
197  bool pmFlagVerifyPasses = false;
198 };
199 
201  PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
202  : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
204 
207  bool pmFlagVerifyPasses) {
208  assert((!impl->localReproducer ||
209  !op->getContext()->isMultithreadingEnabled()) &&
210  "expected multi-threading to be disabled when generating a local "
211  "reproducer");
212 
213  llvm::CrashRecoveryContext::Enable();
214  impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
215 
216  // If we aren't generating a local reproducer, prepare a reproducer for the
217  // given top-level operation.
218  if (!impl->localReproducer)
219  prepareReproducerFor(passes, op);
220 }
221 
222 static void
224  std::pair<Pass *, Operation *> passOpPair) {
225  os << "`" << passOpPair.first->getName() << "` on "
226  << "'" << passOpPair.second->getName() << "' operation";
227  if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
228  os << ": @" << symbol.getName();
229 }
230 
232  LogicalResult executionResult) {
233  // Don't generate a reproducer if we have no active contexts.
234  if (impl->activeContexts.empty())
235  return;
236 
237  // If the pass manager execution succeeded, we don't generate any reproducers.
238  if (succeeded(executionResult))
239  return impl->activeContexts.clear();
240 
242  << "Failures have been detected while "
243  "processing an MLIR pass pipeline";
244 
245  // If we are generating a global reproducer, we include all of the running
246  // passes in the error message for the only active context.
247  if (!impl->localReproducer) {
248  assert(impl->activeContexts.size() == 1 && "expected one active context");
249 
250  // Generate the reproducer.
251  std::string description;
252  impl->activeContexts.front()->generate(description);
253 
254  // Emit an error to the user.
255  Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
256  llvm::interleaveComma(impl->runningPasses, note,
257  [&](const std::pair<Pass *, Operation *> &value) {
258  formatPassOpReproducerMessage(note, value);
259  });
260  note << "]: " << description;
261  impl->runningPasses.clear();
262  impl->activeContexts.clear();
263  return;
264  }
265 
266  // If we were generating a local reproducer, we generate a reproducer for the
267  // most recently executing pass using the matching entry from `runningPasses`
268  // to generate a localized diagnostic message.
269  assert(impl->activeContexts.size() == impl->runningPasses.size() &&
270  "expected running passes to match active contexts");
271 
272  // Generate the reproducer.
273  RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
274  std::string description;
275  reproducerContext.generate(description);
276 
277  // Emit an error to the user.
278  Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
279  formatPassOpReproducerMessage(note, impl->runningPasses.back());
280  note << ": " << description;
281 
282  impl->activeContexts.clear();
283  impl->runningPasses.clear();
284 }
285 
287  Operation *op) {
288  // If not tracking local reproducers, we simply remember that this pass is
289  // running.
290  impl->runningPasses.insert(std::make_pair(pass, op));
291  if (!impl->localReproducer)
292  return;
293 
294  // Disable the current pass recovery context, if there is one. This may happen
295  // in the case of dynamic pass pipelines.
296  if (!impl->activeContexts.empty())
297  impl->activeContexts.back()->disable();
298 
299  // Collect all of the parent scopes of this operation.
301  while (Operation *parentOp = op->getParentOp()) {
302  scopes.push_back(op->getName());
303  op = parentOp;
304  }
305 
306  // Emit a pass pipeline string for the current pass running on the current
307  // operation type.
308  std::string passStr;
309  llvm::raw_string_ostream passOS(passStr);
310  for (OperationName scope : llvm::reverse(scopes))
311  passOS << scope << "(";
312  pass->printAsTextualPipeline(passOS);
313  for (unsigned i = 0, e = scopes.size(); i < e; ++i)
314  passOS << ")";
315 
316  impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
317  passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
318 }
321  std::string passStr;
322  llvm::raw_string_ostream passOS(passStr);
323  llvm::interleaveComma(
324  passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
325 
326  impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
327  passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses));
328 }
329 
331  Operation *op) {
332  // We only pop the active context if we are tracking local reproducers.
333  impl->runningPasses.remove(std::make_pair(pass, op));
334  if (impl->localReproducer) {
335  impl->activeContexts.pop_back();
336 
337  // Re-enable the previous pass recovery context, if there was one. This may
338  // happen in the case of dynamic pass pipelines.
339  if (!impl->activeContexts.empty())
340  impl->activeContexts.back()->enable();
341  }
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // CrashReproducerInstrumentation
346 //===----------------------------------------------------------------------===//
347 
348 namespace {
/// Pass instrumentation that forwards the pass manager's before/after/failed
/// callbacks to a PassCrashReproducerGenerator.
349 struct CrashReproducerInstrumentation : public PassInstrumentation {
350  CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
351  : generator(generator) {}
352  ~CrashReproducerInstrumentation() override = default;
353 
     /// Ask the generator to prepare a reproducer for `pass` running on `op`.
     /// OpToOpPassAdaptor instances are skipped.
354  void runBeforePass(Pass *pass, Operation *op) override {
355  if (!isa<OpToOpPassAdaptor>(pass))
356  generator.prepareReproducerFor(pass, op);
357  }
358 
     /// Drop the reproducer recorded for `pass` on `op` once the pass has
     /// finished. OpToOpPassAdaptor instances are skipped, mirroring
     /// runBeforePass.
359  void runAfterPass(Pass *pass, Operation *op) override {
360  if (!isa<OpToOpPassAdaptor>(pass))
361  generator.removeLastReproducerFor(pass, op);
362  }
363 
     /// Finalize the generator with a failure result the first time any pass
     /// fails; subsequent failures are ignored.
364  void runAfterPassFailed(Pass *pass, Operation *op) override {
365  // Only generate one reproducer per crash reproducer instrumentation.
366  if (alreadyFailed)
367  return;
368 
369  alreadyFailed = true;
370  generator.finalize(op, /*executionResult=*/failure());
371  }
372 
373 private:
374  /// The generator used to create crash reproducers.
     /// NOTE(review): the declaration of the `generator` member initialized by
     /// the constructor is not visible at this point in the listing -- confirm
     /// against the original file.

     /// Set once the first pass failure has been handled, so that later failures
     /// do not finalize the generator again.
376  bool alreadyFailed = false;
377 };
378 } // namespace
379 
380 //===----------------------------------------------------------------------===//
381 // FileReproducerStream
382 //===----------------------------------------------------------------------===//
383 
384 namespace {
385 /// This class represents a default instance of PassManager::ReproducerStream
386 /// that is backed by a file.
/// It takes ownership of an llvm::ToolOutputFile and exposes that file's name
/// and output stream through the ReproducerStream interface.
387 struct FileReproducerStream : public PassManager::ReproducerStream {
     /// Take ownership of `outputFile`. The destructor and accessors dereference
     /// it unconditionally, so it is expected to be non-null.
388  FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
389  : outputFile(std::move(outputFile)) {}
     /// Mark the file as kept when the stream is destroyed.
     /// NOTE(review): presumably needed so that the ToolOutputFile does not
     /// remove the reproducer when it is destroyed -- confirm against
     /// llvm::ToolOutputFile::keep.
390  ~FileReproducerStream() override { outputFile->keep(); }
391 
392  /// Returns a description of the reproducer stream.
     /// For this implementation that is the name of the backing file.
393  StringRef description() override { return outputFile->getFilename(); }
394 
395  /// Returns the stream on which to output the reproducer.
396  raw_ostream &os() override { return outputFile->os(); }
397 
398 private:
399  /// ToolOutputFile corresponding to opened `filename`.
400  std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
401 };
402 } // namespace
403 
404 //===----------------------------------------------------------------------===//
405 // PassManager
406 //===----------------------------------------------------------------------===//
407 
/// Run the pass pipeline on `op` with crash reproducer generation enabled.
///
/// The crash reproducer generator is initialized with the current passes, the
/// pipeline is executed through llvm::CrashRecoveryContext::RunSafelyOnThread,
/// and the generator is then finalized with the outcome. Returns the result of
/// runPasses, or failure if the guarded call never got to assign one.
408 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
409  AnalysisManager am) {
410  crashReproGenerator->initialize(getPasses(), op, verifyPasses);
411 
412  // Safely invoke the passes within a recovery context.
     // The result starts out as failure so that a run which is cut short inside
     // the recovery context is still reported as failed.
413  LogicalResult passManagerResult = failure();
414  llvm::CrashRecoveryContext recoveryContext;
415  recoveryContext.RunSafelyOnThread(
416  [&] { passManagerResult = runPasses(op, am); });
     // The generator decides from the result whether a reproducer has to be
     // written.
417  crashReproGenerator->finalize(op, passManagerResult);
418  return passManagerResult;
419 }
420 
422  bool genLocalReproducer) {
423  // Capture the filename by value in case outputFile is out of scope when
424  // invoked.
425  std::string filename = outputFile.str();
426  enableCrashReproducerGeneration(
427  [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
428  std::unique_ptr<llvm::ToolOutputFile> outputFile =
429  mlir::openOutputFile(filename, &error);
430  if (!outputFile) {
431  error = "Failed to create reproducer stream: " + error;
432  return nullptr;
433  }
434  return std::make_unique<FileReproducerStream>(std::move(outputFile));
435  },
436  genLocalReproducer);
437 }
438 
440  ReproducerStreamFactory factory, bool genLocalReproducer) {
441  assert(!crashReproGenerator &&
442  "crash reproducer has already been initialized");
443  if (genLocalReproducer && getContext()->isMultithreadingEnabled())
444  llvm::report_fatal_error(
445  "Local crash reproduction can't be setup on a "
446  "pass-manager without disabling multi-threading first.");
447 
448  crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
449  factory, genLocalReproducer);
450  addInstrumentation(
451  std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // Asm Resource
456 //===----------------------------------------------------------------------===//
457 
459  auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult {
460  if (entry.getKey() == "pipeline") {
461  FailureOr<std::string> value = entry.parseAsString();
462  if (succeeded(value))
463  this->pipeline = std::move(*value);
464  return value;
465  }
466  if (entry.getKey() == "disable_threading") {
467  FailureOr<bool> value = entry.parseAsBool();
468  if (succeeded(value))
469  this->disableThreading = *value;
470  return value;
471  }
472  if (entry.getKey() == "verify_each") {
473  FailureOr<bool> value = entry.parseAsBool();
474  if (succeeded(value))
475  this->verifyEach = *value;
476  return value;
477  }
478  return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
479  << entry.getKey() << "'";
480  };
481  config.attachResourceParser("mlir_reproducer", parseFn);
482 }
483 
485  if (pipeline.has_value()) {
486  FailureOr<OpPassManager> reproPm = parsePassPipeline(*pipeline);
487  if (failed(reproPm))
488  return failure();
489  static_cast<OpPassManager &>(pm) = std::move(*reproPm);
490  }
491 
492  if (disableThreading.has_value())
493  pm.getContext()->disableMultithreading(*disableThreading);
494 
495  if (verifyEach.has_value())
496  pm.enableVerifier(*verifyEach);
497 
498  return success();
499 }
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static std::string diag(const llvm::Value &value)
static void formatPassOpReproducerMessage(Diagnostic &os, std::pair< Pass *, Operation * > passOpPair)
This class represents an analysis manager for a particular operation instance.
This class represents a single parsed resource entry.
Definition: AsmState.h:283
This class is used to build resource entries for use by the printer.
Definition: AsmState.h:239
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:533
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:48
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:460
void attachResourceParser(std::unique_ptr< AsmResourceParser > parser)
Attach the given resource parser.
Definition: AsmState.h:496
PassInstrumentation provides several entry points into the pass manager infrastructure.
The main pass manager and pipeline builder.
Definition: PassManager.h:211
MLIRContext * getContext() const
Return an instance of the context.
Definition: PassManager.h:236
std::function< std::unique_ptr< ReproducerStream >(std::string &error)> ReproducerStreamFactory
Method type for constructing ReproducerStream.
Definition: PassManager.h:259
void enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer=false)
Enable support for the pass manager to generate a reproducer on the event of a crash or a pass failure.
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
Definition: Pass.cpp:819
The abstract base pass class.
Definition: Pass.h:51
void printAsTextualPipeline(raw_ostream &os)
Prints out the pass in the textual representation of pipelines.
Definition: Pass.cpp:65
void initialize(iterator_range< PassManager::pass_iterator > passes, Operation *op, bool pmFlagVerifyPasses)
Initialize the generator in preparation for reproducer generation.
void removeLastReproducerFor(Pass *pass, Operation *op)
Remove the last recorded reproducer anchored at the given pass and operation.
void finalize(Operation *rootOp, LogicalResult executionResult)
Finalize the current run of the generator, generating any necessary reproducers if the provided execution result is a failure.
PassCrashReproducerGenerator(PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
void prepareReproducerFor(Pass *pass, Operation *op)
Prepare a new reproducer for the given pass, operating on op.
Detect if any of the given parameter types has a sub-element handler.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::unique_ptr< llvm::ToolOutputFile > openOutputFile(llvm::StringRef outputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for writing.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
bool pmFlagVerifyPasses
Various pass manager flags that get emitted when generating a reproducer.
Impl(PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
SetVector< std::pair< Pass *, Operation * > > runningPasses
The set of all currently running passes.
bool localReproducer
Flag indicating if reproducer generation should be localized to the failing pass.
PassManager::ReproducerStreamFactory streamFactory
The factory to use when generating a crash reproducer.
SmallVector< std::unique_ptr< RecoveryReproducerContext > > activeContexts
A record of all of the currently active reproducer contexts.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Streams on which to output crash reproducer.
Definition: PassManager.h:247
void attachResourceParser(ParserConfig &config)
Attach an assembly resource parser to 'config' that collects the MLIR reproducer configuration into t...
LogicalResult apply(PassManager &pm) const
Apply the reproducer options to 'pm' and its context.
This class contains all of the context for generating a recovery reproducer.
void disable()
Disable this reproducer context.
void generate(std::string &description)
Generate a reproducer with the current context.
RecoveryReproducerContext(std::string passPipelineStr, Operation *op, PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
void enable()
Enable a previously disabled reproducer context.