MLIR  22.0.0git
GenerateRuntimeVerification.cpp
Go to the documentation of this file.
1 //===- GenerateRuntimeVerification.cpp - Runtime Op Verification ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/IR/AsmState.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION
18 #include "mlir/Transforms/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
23 namespace {
24 struct GenerateRuntimeVerificationPass
25  : public impl::GenerateRuntimeVerificationBase<
26  GenerateRuntimeVerificationPass> {
27  void runOnOperation() override;
28 };
29 
30 /// Default error message generator for runtime verification failures.
31 ///
32 /// This class generates error messages with different levels of verbosity:
33 /// - Level 0: Shows only the error message and operation location
34 /// - Level 1: Shows the full operation string, error message, and location
35 ///
36 /// Clients can call getVerboseLevel() to retrieve the current verbose level
37 /// and use it to customize their own error message generators with similar
38 /// behavior patterns.
39 class DefaultErrMsgGenerator {
40 private:
41  unsigned vLevel;
42  AsmState &state;
43 
44 public:
45  DefaultErrMsgGenerator(unsigned verboseLevel, AsmState &asmState)
46  : vLevel(verboseLevel), state(asmState) {}
47 
48  std::string operator()(Operation *op, StringRef msg) {
49  std::string buffer;
50  llvm::raw_string_ostream stream(buffer);
51  stream << "ERROR: Runtime op verification failed\n";
52  if (vLevel == 1) {
53  op->print(stream, state);
54  stream << "\n^ " << msg;
55  } else {
56  stream << "^ " << msg;
57  }
58  stream << "\nLocation: ";
59  op->getLoc().print(stream);
60  return buffer;
61  }
62 
63  unsigned getVerboseLevel() const { return vLevel; }
64 };
65 } // namespace
66 
67 void GenerateRuntimeVerificationPass::runOnOperation() {
68  // Check verboseLevel is in range [0, 1].
69  if (verboseLevel > 1) {
70  getOperation()->emitError(
71  "generate-runtime-verification pass: set verboseLevel to 0 or 1");
72  signalPassFailure();
73  return;
74  }
75 
76  // The implementation of the RuntimeVerifiableOpInterface may create ops that
77  // can be verified. We don't want to generate verification for IR that
78  // performs verification, so gather all runtime-verifiable ops first.
80  getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
81  ops.push_back(verifiableOp);
82  });
83 
84  // We may generate a lot of error messages and so we need to ensure the
85  // printing is fast.
86  OpPrintingFlags flags;
88  flags.skipRegions();
89  flags.useLocalScope();
90  AsmState state(getOperation(), flags);
91 
92  // Client can call getVerboseLevel() to fetch verbose level.
93  DefaultErrMsgGenerator defaultErrMsgGenerator(verboseLevel.getValue(), state);
94 
95  OpBuilder builder(getOperation()->getContext());
96  for (RuntimeVerifiableOpInterface verifiableOp : ops) {
97  builder.setInsertionPoint(verifiableOp);
98  verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(),
99  defaultErrMsgGenerator);
100  };
101 }
102 
104  return std::make_unique<GenerateRuntimeVerificationPass>();
105 }
static MLIRContext * getContext(OpFoldResult val)
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:542
void print(raw_ostream &os) const
Print the location.
Definition: Location.h:97
This class helps build Operations.
Definition: Builders.h:207
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
OpPrintingFlags & elideLargeElementsAttrs(int64_t largeElementLimit=16)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningless form instead of the element data.
Definition: AsmPrinter.cpp:249
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:296
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:282
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void print(raw_ostream &os, const OpPrintingFlags &flags={})
Include the generated interface declarations.
std::unique_ptr< Pass > createGenerateRuntimeVerificationPass()
Creates a pass that generates IR to verify ops at runtime.