MLIR 23.0.0git
GenerateRuntimeVerification.cpp
Go to the documentation of this file.
//===- GenerateRuntimeVerification.cpp - Op Verification ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/AsmState.h"
11
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Operation.h"
15
16namespace mlir {
17#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATIONPASS
18#include "mlir/Transforms/Passes.h.inc"
19} // namespace mlir
20
21using namespace mlir;
22
23namespace {
24struct GenerateRuntimeVerificationPass
25 : public impl::GenerateRuntimeVerificationPassBase<
26 GenerateRuntimeVerificationPass> {
27 using impl::GenerateRuntimeVerificationPassBase<
28 GenerateRuntimeVerificationPass>::GenerateRuntimeVerificationPassBase;
29 void runOnOperation() override;
30};
31
32/// Default error message generator for runtime verification failures.
33///
34/// This class generates error messages with different levels of verbosity:
35/// - Level 0: Shows only the error message and operation location
36/// - Level 1: Shows the full operation string, error message, and location
37///
38/// Clients can call getVerboseLevel() to retrieve the current verbose level
39/// and use it to customize their own error message generators with similar
40/// behavior patterns.
41class DefaultErrMsgGenerator {
42private:
43 unsigned vLevel;
44 AsmState &state;
45
46public:
47 DefaultErrMsgGenerator(unsigned verboseLevel, AsmState &asmState)
48 : vLevel(verboseLevel), state(asmState) {}
49
50 std::string operator()(Operation *op, StringRef msg) {
51 std::string buffer;
52 llvm::raw_string_ostream stream(buffer);
53 stream << "ERROR: Runtime op verification failed\n";
54 if (vLevel == 1) {
55 op->print(stream, state);
56 stream << "\n^ " << msg;
57 } else {
58 stream << "^ " << msg;
59 }
60 stream << "\nLocation: ";
61 op->getLoc().print(stream);
62 return buffer;
63 }
64
65 unsigned getVerboseLevel() const { return vLevel; }
66};
67} // namespace
68
69void GenerateRuntimeVerificationPass::runOnOperation() {
70 // Check verboseLevel is in range [0, 1].
71 if (verboseLevel > 1) {
72 getOperation()->emitError(
73 "generate-runtime-verification pass: set verboseLevel to 0 or 1");
74 signalPassFailure();
75 return;
76 }
77
78 // The implementation of the RuntimeVerifiableOpInterface may create ops that
79 // can be verified. We don't want to generate verification for IR that
80 // performs verification, so gather all runtime-verifiable ops first.
81 SmallVector<RuntimeVerifiableOpInterface> ops;
82 getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
83 ops.push_back(verifiableOp);
84 });
85
86 // We may generate a lot of error messages and so we need to ensure the
87 // printing is fast.
88 OpPrintingFlags flags;
90 flags.skipRegions();
91 flags.useLocalScope();
92 AsmState state(getOperation(), flags);
93
94 // Client can call getVerboseLevel() to fetch verbose level.
95 DefaultErrMsgGenerator defaultErrMsgGenerator(verboseLevel.getValue(), state);
96
97 OpBuilder builder(getOperation()->getContext());
98 for (RuntimeVerifiableOpInterface verifiableOp : ops) {
99 builder.setInsertionPoint(verifiableOp);
100 verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(),
101 defaultErrMsgGenerator);
102 };
103}
b getContext())
void print(raw_ostream &os) const
Print the location.
Definition Location.h:97
OpPrintingFlags & elideLargeElementsAttrs(int64_t largeElementLimit=16)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
void print(raw_ostream &os, const OpPrintingFlags &flags={})
Include the generated interface declarations.