18#define GEN_PASS_DEF_INTERPRETERPASS
19#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
34 auto tagAttrName = StringAttr::get(
35 passRoot->
getContext(), transform::TransformDialect::kTargetTagAttrName);
38 if (!attr || attr.getValue() != tag)
47 <<
"repeated operation with the target tag '"
49 diag.attachNote(
target->getLoc()) <<
"previously seen operation";
55 <<
"could not find the operation with transform.target_tag=\"" << tag
67 std::optional<RaggedArray<transform::MappedValue>>
68 parseArguments(Operation *payloadRoot) {
69 MLIRContext *context = payloadRoot->
getContext();
71 SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
72 trailingBindings.resize(debugBindTrailingArgs.size());
75 SmallVector<std::optional<OperationName>> debugBindNames;
76 debugBindNames.reserve(debugBindTrailingArgs.size());
77 for (
auto &&[position, nameString] :
78 llvm::enumerate(debugBindTrailingArgs)) {
79 StringRef name = nameString;
82 if (name.starts_with(
"#")) {
83 debugBindNames.push_back(std::nullopt);
85 StringRef
rhs = name.drop_front();
89 if (
lhs.getAsInteger(10, value)) {
91 <<
"couldn't parse integer pass argument " << name;
94 trailingBindings[position].push_back(
95 Builder(context).getI64IntegerAttr(value));
96 }
while (!
rhs.empty());
97 }
else if (name.starts_with(
"^")) {
98 debugBindNames.emplace_back(OperationName(name.drop_front(), context));
100 debugBindNames.emplace_back(OperationName(name, context));
105 payloadRoot->
walk([&](Operation *payload) {
106 for (
auto &&[position, name] : llvm::enumerate(debugBindNames)) {
107 if (!name || payload->
getName() != *name)
110 if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
112 llvm::append_range(trailingBindings[position], payload->
getResults());
114 trailingBindings[position].push_back(payload);
119 RaggedArray<transform::MappedValue> bindings;
120 bindings.
push_back(ArrayRef<Operation *>{payloadRoot});
121 for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
129 void runOnOperation()
override {
131 ModuleOp transformModule =
133 Operation *payloadRoot =
136 return signalPassFailure();
139 getOperation(), transformModule, entryPoint);
140 if (!transformEntryPoint)
141 return signalPassFailure();
143 std::optional<RaggedArray<transform::MappedValue>> bindings =
144 parseArguments(payloadRoot);
146 return signalPassFailure();
149 cast<transform::TransformOpInterface>(transformEntryPoint),
151 options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152 return signalPassFailure();
158 transform::TransformOptions options;
static Operation * findPayloadRoot(Operation *passRoot, StringRef tag)
Returns the payload operation to be used as payload root:
static std::string diag(const llvm::Value &value)
This class represents a diagnostic that is inflight and set to be reported.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.