MLIR 22.0.0git
InterpreterPass.cpp
Go to the documentation of this file.
1//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13
14using namespace mlir;
15
16namespace mlir {
17namespace transform {
18#define GEN_PASS_DEF_INTERPRETERPASS
19#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
20} // namespace transform
21} // namespace mlir
22
23/// Returns the payload operation to be used as payload root:
24/// - the operation nested under `passRoot` that has the given tag attribute,
25/// must be unique;
26/// - the `passRoot` itself if the tag is empty.
27static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
28 // Fast return.
29 if (tag.empty())
30 return passRoot;
31
32 // Walk to do a lookup.
33 Operation *target = nullptr;
34 auto tagAttrName = StringAttr::get(
35 passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
36 WalkResult walkResult = passRoot->walk([&](Operation *op) {
37 auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
38 if (!attr || attr.getValue() != tag)
39 return WalkResult::advance();
40
41 if (!target) {
42 target = op;
43 return WalkResult::advance();
44 }
45
47 << "repeated operation with the target tag '"
48 << tag << "'";
49 diag.attachNote(target->getLoc()) << "previously seen operation";
50 return WalkResult::interrupt();
51 });
52
53 if (!target) {
54 passRoot->emitError()
55 << "could not find the operation with transform.target_tag=\"" << tag
56 << "\" attribute";
57 return nullptr;
58 }
59
60 return walkResult.wasInterrupted() ? nullptr : target;
61}
62
63namespace {
64class InterpreterPass
65 : public transform::impl::InterpreterPassBase<InterpreterPass> {
66 // Parses the pass arguments to bind trailing arguments of the entry point.
67 std::optional<RaggedArray<transform::MappedValue>>
68 parseArguments(Operation *payloadRoot) {
69 MLIRContext *context = payloadRoot->getContext();
70
71 SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
72 trailingBindings.resize(debugBindTrailingArgs.size());
73
74 // Construct lists of op names to match.
75 SmallVector<std::optional<OperationName>> debugBindNames;
76 debugBindNames.reserve(debugBindTrailingArgs.size());
77 for (auto &&[position, nameString] :
78 llvm::enumerate(debugBindTrailingArgs)) {
79 StringRef name = nameString;
80
81 // Parse the integer literals.
82 if (name.starts_with("#")) {
83 debugBindNames.push_back(std::nullopt);
84 StringRef lhs = "";
85 StringRef rhs = name.drop_front();
86 do {
87 std::tie(lhs, rhs) = rhs.split(';');
88 int64_t value;
89 if (lhs.getAsInteger(10, value)) {
90 emitError(UnknownLoc::get(context))
91 << "couldn't parse integer pass argument " << name;
92 return std::nullopt;
93 }
94 trailingBindings[position].push_back(
95 Builder(context).getI64IntegerAttr(value));
96 } while (!rhs.empty());
97 } else if (name.starts_with("^")) {
98 debugBindNames.emplace_back(OperationName(name.drop_front(), context));
99 } else {
100 debugBindNames.emplace_back(OperationName(name, context));
101 }
102 }
103
104 // Collect operations or results for extra bindings.
105 payloadRoot->walk([&](Operation *payload) {
106 for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
107 if (!name || payload->getName() != *name)
108 continue;
109
110 if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
111 .starts_with("^")) {
112 llvm::append_range(trailingBindings[position], payload->getResults());
113 } else {
114 trailingBindings[position].push_back(payload);
115 }
116 }
117 });
118
119 RaggedArray<transform::MappedValue> bindings;
120 bindings.push_back(ArrayRef<Operation *>{payloadRoot});
121 for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
122 bindings.push_back(std::move(trailing));
123 return bindings;
124 }
125
126public:
127 using Base::Base;
128
129 void runOnOperation() override {
130 MLIRContext *context = &getContext();
131 ModuleOp transformModule =
133 Operation *payloadRoot =
134 findPayloadRoot(getOperation(), debugPayloadRootTag);
135 if (!payloadRoot)
136 return signalPassFailure();
137
138 Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
139 getOperation(), transformModule, entryPoint);
140 if (!transformEntryPoint)
141 return signalPassFailure();
142
143 std::optional<RaggedArray<transform::MappedValue>> bindings =
144 parseArguments(payloadRoot);
145 if (!bindings)
146 return signalPassFailure();
148 *bindings,
149 cast<transform::TransformOpInterface>(transformEntryPoint),
150 transformModule,
151 options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152 return signalPassFailure();
153 }
154 }
155
156private:
157 /// Transform interpreter options.
158 transform::TransformOptions options;
159};
160} // namespace
lhs
static Operation * findPayloadRoot(Operation *passRoot, StringRef tag)
Returns the payload operation to be used as payload root:
b getContext())
static std::string diag(const llvm::Value &value)
This class represents a diagnostic that is inflight and set to be reported.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
ModuleOp getPreloadedTransformModule(MLIRContext *context)
Utility to load a transform interpreter module from a module that has already been preloaded in the context.
TransformOpInterface findTransformEntryPoint(Operation *root, ModuleOp module, StringRef entryPoint=TransformDialect::kTransformEntryPointSymbolName)
Finds the first TransformOpInterface named kTransformEntryPointSymbolName that is either: (1) nested under `root` (takes precedence), or (2) nested under `module`, if not null.
LogicalResult applyTransformNamedSequence(Operation *payload, Operation *transformRoot, ModuleOp transformModule, const TransformOptions &options)
Standalone util to apply the named sequence transformRoot to payload IR.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.