MLIR  19.0.0git
InterpreterPass.cpp
Go to the documentation of this file.
1 //===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace transform {
18 #define GEN_PASS_DEF_INTERPRETERPASS
19 #include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
20 } // namespace transform
21 } // namespace mlir
22 
23 /// Returns the payload operation to be used as payload root:
24 /// - the operation nested under `passRoot` that has the given tag attribute,
25 /// must be unique;
26 /// - the `passRoot` itself if the tag is empty.
27 static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
28  // Fast return.
29  if (tag.empty())
30  return passRoot;
31 
32  // Walk to do a lookup.
33  Operation *target = nullptr;
34  auto tagAttrName = StringAttr::get(
35  passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);
36  WalkResult walkResult = passRoot->walk([&](Operation *op) {
37  auto attr = op->getAttrOfType<StringAttr>(tagAttrName);
38  if (!attr || attr.getValue() != tag)
39  return WalkResult::advance();
40 
41  if (!target) {
42  target = op;
43  return WalkResult::advance();
44  }
45 
47  << "repeated operation with the target tag '"
48  << tag << "'";
49  diag.attachNote(target->getLoc()) << "previously seen operation";
50  return WalkResult::interrupt();
51  });
52 
53  if (!target) {
54  passRoot->emitError()
55  << "could not find the operation with transform.target_tag=\"" << tag
56  << "\" attribute";
57  return nullptr;
58  }
59 
60  return walkResult.wasInterrupted() ? nullptr : target;
61 }
62 
63 namespace {
64 class InterpreterPass
65  : public transform::impl::InterpreterPassBase<InterpreterPass> {
66  // Parses the pass arguments to bind trailing arguments of the entry point.
67  std::optional<RaggedArray<transform::MappedValue>>
68  parseArguments(Operation *payloadRoot) {
69  MLIRContext *context = payloadRoot->getContext();
70 
72  trailingBindings.resize(debugBindTrailingArgs.size());
73 
74  // Construct lists of op names to match.
76  debugBindNames.reserve(debugBindTrailingArgs.size());
77  for (auto &&[position, nameString] :
78  llvm::enumerate(debugBindTrailingArgs)) {
79  StringRef name = nameString;
80 
81  // Parse the integer literals.
82  if (name.starts_with("#")) {
83  debugBindNames.push_back(std::nullopt);
84  StringRef lhs = "";
85  StringRef rhs = name.drop_front();
86  do {
87  std::tie(lhs, rhs) = rhs.split(';');
88  int64_t value;
89  if (lhs.getAsInteger(10, value)) {
90  emitError(UnknownLoc::get(context))
91  << "couldn't parse integer pass argument " << name;
92  return std::nullopt;
93  }
94  trailingBindings[position].push_back(
95  Builder(context).getI64IntegerAttr(value));
96  } while (!rhs.empty());
97  } else if (name.starts_with("^")) {
98  debugBindNames.emplace_back(OperationName(name.drop_front(), context));
99  } else {
100  debugBindNames.emplace_back(OperationName(name, context));
101  }
102  }
103 
104  // Collect operations or results for extra bindings.
105  payloadRoot->walk([&](Operation *payload) {
106  for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
107  if (!name || payload->getName() != *name)
108  continue;
109 
110  if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
111  .starts_with("^")) {
112  llvm::append_range(trailingBindings[position], payload->getResults());
113  } else {
114  trailingBindings[position].push_back(payload);
115  }
116  }
117  });
118 
120  bindings.push_back(ArrayRef<Operation *>{payloadRoot});
121  for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
122  bindings.push_back(std::move(trailing));
123  return bindings;
124  }
125 
126 public:
127  using Base::Base;
128 
129  void runOnOperation() override {
130  MLIRContext *context = &getContext();
131  ModuleOp transformModule =
133  Operation *payloadRoot =
134  findPayloadRoot(getOperation(), debugPayloadRootTag);
135  if (!payloadRoot)
136  return signalPassFailure();
137 
139  getOperation(), transformModule, entryPoint);
140  if (!transformEntryPoint)
141  return signalPassFailure();
142 
143  std::optional<RaggedArray<transform::MappedValue>> bindings =
144  parseArguments(payloadRoot);
145  if (!bindings)
146  return signalPassFailure();
148  *bindings,
149  cast<transform::TransformOpInterface>(transformEntryPoint),
150  transformModule,
151  options.enableExpensiveChecks(!disableExpensiveChecks)))) {
152  return signalPassFailure();
153  }
154  }
155 
156 private:
157  /// Transform interpreter options.
159 };
160 } // namespace
static MLIRContext * getContext(OpFoldResult val)
static Operation * findPayloadRoot(Operation *passRoot, StringRef tag)
Returns the payload operation to be used as payload root:
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
Definition: RaggedArray.h:125
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
Options controlling the application of transform operations by the TransformState.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
ModuleOp getPreloadedTransformModule(MLIRContext *context)
Utility to load a transform interpreter module from a module that has already been preloaded in the c...
TransformOpInterface findTransformEntryPoint(Operation *root, ModuleOp module, StringRef entryPoint=TransformDialect::kTransformEntryPointSymbolName)
Finds the first TransformOpInterface named kTransformEntryPointSymbolName that is either:
LogicalResult applyTransformNamedSequence(Operation *payload, Operation *transformRoot, ModuleOp transformModule, const TransformOptions &options)
Standalone util to apply the named sequence transformRoot to payload IR.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72