MLIR  21.0.0git
TransformInterpreterUtils.cpp
Go to the documentation of this file.
1 //===- TransformInterpreterUtils.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lightweight transform dialect interpreter utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/IR/Visitors.h"
21 #include "mlir/Parser/Parser.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FileSystem.h"
26 #include "llvm/Support/SourceMgr.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace mlir;
30 
31 #define DEBUG_TYPE "transform-dialect-interpreter-utils"
32 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
33 
34 /// Expands the given list of `paths` to a list of `.mlir` files.
35 ///
36 /// Each entry in `paths` may either be a regular file, in which case it ends up
37 /// in the result list, or a directory, in which case all (regular) `.mlir`
38 /// files in that directory are added. Any other file types lead to a failure.
40  ArrayRef<std::string> paths, MLIRContext *context,
41  SmallVectorImpl<std::string> &fileNames) {
42  for (const std::string &path : paths) {
43  auto loc = FileLineColLoc::get(context, path, 0, 0);
44 
45  if (llvm::sys::fs::is_regular_file(path)) {
46  LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
47  fileNames.push_back(path);
48  continue;
49  }
50 
51  if (!llvm::sys::fs::is_directory(path)) {
52  return emitError(loc)
53  << "'" << path << "' is neither a file nor a directory";
54  }
55 
56  LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
57 
58  std::error_code ec;
59  for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
60  it != itEnd && !ec; it.increment(ec)) {
61  const std::string &fileName = it->path();
62 
63  if (it->type() != llvm::sys::fs::file_type::regular_file &&
64  it->type() != llvm::sys::fs::file_type::symlink_file) {
65  LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
66  << "'\n");
67  continue;
68  }
69 
70  if (!StringRef(fileName).ends_with(".mlir")) {
71  LLVM_DEBUG(DBGS() << " Skipping '" << fileName
72  << "' because it does not end with '.mlir'\n");
73  continue;
74  }
75 
76  LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
77  fileNames.push_back(fileName);
78  }
79 
80  if (ec)
81  return emitError(loc) << "error while opening files in '" << path
82  << "': " << ec.message();
83  }
84 
85  return success();
86 }
87 
89  MLIRContext *context, llvm::StringRef transformFileName,
90  OwningOpRef<ModuleOp> &transformModule) {
91  if (transformFileName.empty()) {
92  LLVM_DEBUG(
93  DBGS() << "no transform file name specified, assuming the transform "
94  "module is embedded in the IR next to the top-level\n");
95  return success();
96  }
97  // Parse transformFileName content into a ModuleOp.
98  std::string errorMessage;
99  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
100  if (!memoryBuffer) {
102  StringAttr::get(context, transformFileName), 0, 0))
103  << "failed to open transform file: " << errorMessage;
104  }
105  // Tell sourceMgr about this buffer, the parser will pick it up.
106  llvm::SourceMgr sourceMgr;
107  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
108  transformModule =
109  OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
110  if (!transformModule) {
111  // Failed to parse the transform module.
112  // Don't need to emit an error here as the parsing should have already done
113  // that.
114  return failure();
115  }
116  return mlir::verify(*transformModule);
117 }
118 
120  return context->getOrLoadDialect<transform::TransformDialect>()
121  ->getLibraryModule();
122 }
123 
124 transform::TransformOpInterface
126  StringRef entryPoint) {
128  if (module)
129  l.push_back(module);
130  for (Operation *op : l) {
131  transform::TransformOpInterface transform = nullptr;
133  [&](transform::NamedSequenceOp namedSequenceOp) {
134  if (namedSequenceOp.getSymName() == entryPoint) {
135  transform = cast<transform::TransformOpInterface>(
136  namedSequenceOp.getOperation());
137  return WalkResult::interrupt();
138  }
139  return WalkResult::advance();
140  });
141  if (transform)
142  return transform;
143  }
144  auto diag = root->emitError()
145  << "could not find a nested named sequence with name: "
146  << entryPoint;
147  return nullptr;
148 }
149 
151  MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
152  OwningOpRef<ModuleOp> &transformModule) {
153  // Assemble list of library files.
154  SmallVector<std::string> libraryFileNames;
155  if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
156  libraryFileNames)))
157  return failure();
158 
159  // Parse modules from library files.
160  SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
161  for (const std::string &libraryFileName : libraryFileNames) {
162  OwningOpRef<ModuleOp> parsedLibrary;
164  context, libraryFileName, parsedLibrary)))
165  return failure();
166  parsedLibraries.push_back(std::move(parsedLibrary));
167  }
168 
169  // Merge parsed libraries into one module.
170  auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
171  OwningOpRef<ModuleOp> mergedParsedLibraries =
172  ModuleOp::create(loc, "__transform");
173  {
174  mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
175  UnitAttr::get(context));
176  // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
177  for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
179  mergedParsedLibraries.get(), std::move(parsedLibrary))))
180  return parsedLibrary->emitError()
181  << "failed to merge symbols into shared library module";
182  }
183  }
184 
185  transformModule = std::move(mergedParsedLibraries);
186  return success();
187 }
188 
190  Operation *payload, Operation *transformRoot, ModuleOp transformModule,
191  const TransformOptions &options) {
192  RaggedArray<MappedValue> bindings;
193  bindings.push_back(ArrayRef<Operation *>{payload});
194  return applyTransformNamedSequence(bindings,
195  cast<TransformOpInterface>(transformRoot),
196  transformModule, options);
197 }
198 
200  RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot,
201  ModuleOp transformModule, const TransformOptions &options) {
202  if (bindings.empty()) {
203  return transformRoot.emitError()
204  << "expected at least one binding for the root";
205  }
206  if (bindings.at(0).size() != 1) {
207  return transformRoot.emitError()
208  << "expected one payload to be bound to the first argument, got "
209  << bindings.at(0).size();
210  }
211  auto *payloadRoot = dyn_cast<Operation *>(bindings.at(0).front());
212  if (!payloadRoot) {
213  return transformRoot->emitError() << "expected the object bound to the "
214  "first argument to be an operation";
215  }
216 
217  bindings.removeFront();
218 
219  // `transformModule` may not be modified.
220  if (transformModule && !transformModule->isAncestor(transformRoot)) {
221  OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
222  if (failed(detail::mergeSymbolsInto(
223  SymbolTable::getNearestSymbolTable(transformRoot),
224  std::move(clonedTransformModule)))) {
225  return payloadRoot->emitError() << "failed to merge symbols";
226  }
227  }
228 
229  LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
230  LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");
231 
232  return applyTransforms(payloadRoot, transformRoot, bindings, options,
233  /*enforceToplevelTransformOp=*/false);
234 }
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define DBGS()
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:161
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OpTy get() const
Allow accessing the internal op.
Definition: OwningOpRef.h:51
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
ArrayRef< T > at(size_t pos) const
Definition: RaggedArray.h:29
void removeFront()
Removes the first subarray in-place. Invalidates iterators to all rows.
Definition: RaggedArray.h:154
bool empty() const
Returns true if there are no rows in the 2D array.
Definition: RaggedArray.h:25
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
Definition: RaggedArray.h:125
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Options controlling the application of transform operations by the TransformState.
LogicalResult assembleTransformLibraryFromPaths(MLIRContext *context, ArrayRef< std::string > transformLibraryPaths, OwningOpRef< ModuleOp > &transformModule)
Utility to parse, verify, aggregate and link the content of all mlir files nested under transformLibr...
LogicalResult parseTransformModuleFromFile(MLIRContext *context, llvm::StringRef transformFileName, OwningOpRef< ModuleOp > &transformModule)
Utility to parse and verify the content of a transformFileName MLIR file containing a transform diale...
ModuleOp getPreloadedTransformModule(MLIRContext *context)
Utility to load a transform interpreter module from a module that has already been preloaded in the c...
InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
Definition: Utils.cpp:79
LogicalResult expandPathsToMLIRFiles(ArrayRef< std::string > paths, MLIRContext *context, SmallVectorImpl< std::string > &fileNames)
Expands the given list of paths to a list of .mlir files.
TransformOpInterface findTransformEntryPoint(Operation *root, ModuleOp module, StringRef entryPoint=TransformDialect::kTransformEntryPointSymbolName)
Finds the first TransformOpInterface named kTransformEntryPointSymbolName that is either:
LogicalResult applyTransformNamedSequence(Operation *payload, Operation *transformRoot, ModuleOp transformModule, const TransformOptions &options)
Standalone util to apply the named sequence transformRoot to payload IR.
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423