18 #define GEN_PASS_DEF_INTERPRETERPASS
19 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
35 passRoot->
getContext(), transform::TransformDialect::kTargetTagAttrName);
38 if (!attr || attr.getValue() != tag)
47 <<
"repeated operation with the target tag '"
49 diag.attachNote(target->
getLoc()) <<
"previously seen operation";
55 <<
"could not find the operation with transform.target_tag=\"" << tag
65 :
public transform::impl::InterpreterPassBase<InterpreterPass> {
67 std::optional<RaggedArray<transform::MappedValue>>
72 trailingBindings.resize(debugBindTrailingArgs.size());
76 debugBindNames.reserve(debugBindTrailingArgs.size());
77 for (
auto &&[position, nameString] :
79 StringRef name = nameString;
82 if (name.starts_with(
"#")) {
83 debugBindNames.push_back(std::nullopt);
85 StringRef rhs = name.drop_front();
87 std::tie(lhs, rhs) = rhs.split(
';');
89 if (lhs.getAsInteger(10, value)) {
91 <<
"couldn't parse integer pass argument " << name;
94 trailingBindings[position].push_back(
95 Builder(context).getI64IntegerAttr(value));
96 }
while (!rhs.empty());
97 }
else if (name.starts_with(
"^")) {
98 debugBindNames.emplace_back(
OperationName(name.drop_front(), context));
107 if (!name || payload->
getName() != *name)
110 if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
112 llvm::append_range(trailingBindings[position], payload->getResults());
114 trailingBindings[position].push_back(payload);
129 void runOnOperation()
override {
131 ModuleOp transformModule =
136 return signalPassFailure();
139 getOperation(), transformModule, entryPoint);
140 if (!transformEntryPoint)
141 return signalPassFailure();
143 std::optional<RaggedArray<transform::MappedValue>> bindings =
144 parseArguments(payloadRoot);
146 return signalPassFailure();
149 cast<transform::TransformOpInterface>(transformEntryPoint),
151 options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152 return signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static Operation * findPayloadRoot(Operation *passRoot, StringRef tag)
Returns the payload operation to be used as payload root:
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
This class is a general helper class for creating context-global objects like types,...
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
A 2D array where each row may have different length.
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...