MLIR 22.0.0git
TransformInterpreterUtils.cpp
Go to the documentation of this file.
1//===- TransformInterpreterUtils.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Lightweight transform dialect interpreter utilities.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/Verifier.h"
20#include "mlir/IR/Visitors.h"
21#include "mlir/Parser/Parser.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/FileSystem.h"
26#include "llvm/Support/SourceMgr.h"
27#include "llvm/Support/raw_ostream.h"
28
29using namespace mlir;
30
31#define DEBUG_TYPE "transform-dialect-interpreter-utils"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
33
34/// Expands the given list of `paths` to a list of `.mlir` files.
35///
36/// Each entry in `paths` may either be a regular file, in which case it ends up
37/// in the result list, or a directory, in which case all (regular) `.mlir`
38/// files in that directory are added. Any other file types lead to a failure.
40 ArrayRef<std::string> paths, MLIRContext *context,
42 for (const std::string &path : paths) {
43 auto loc = FileLineColLoc::get(context, path, 0, 0);
44
45 if (llvm::sys::fs::is_regular_file(path)) {
46 LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
47 fileNames.push_back(path);
48 continue;
49 }
50
51 if (!llvm::sys::fs::is_directory(path)) {
52 return emitError(loc)
53 << "'" << path << "' is neither a file nor a directory";
54 }
55
56 LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
57
58 std::error_code ec;
59 for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
60 it != itEnd && !ec; it.increment(ec)) {
61 const std::string &fileName = it->path();
62
63 if (it->type() != llvm::sys::fs::file_type::regular_file &&
64 it->type() != llvm::sys::fs::file_type::symlink_file) {
65 LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
66 << "'\n");
67 continue;
68 }
69
70 if (!StringRef(fileName).ends_with(".mlir")) {
71 LLVM_DEBUG(DBGS() << " Skipping '" << fileName
72 << "' because it does not end with '.mlir'\n");
73 continue;
74 }
75
76 LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
77 fileNames.push_back(fileName);
78 }
79
80 if (ec)
81 return emitError(loc) << "error while opening files in '" << path
82 << "': " << ec.message();
83 }
84
85 return success();
86}
87
89 MLIRContext *context, llvm::StringRef transformFileName,
90 OwningOpRef<ModuleOp> &transformModule) {
91 if (transformFileName.empty()) {
92 LLVM_DEBUG(
93 DBGS() << "no transform file name specified, assuming the transform "
94 "module is embedded in the IR next to the top-level\n");
95 return success();
96 }
97 // Parse transformFileName content into a ModuleOp.
98 std::string errorMessage;
99 auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
100 if (!memoryBuffer) {
102 StringAttr::get(context, transformFileName), 0, 0))
103 << "failed to open transform file: " << errorMessage;
104 }
105 // Tell sourceMgr about this buffer, the parser will pick it up.
106 llvm::SourceMgr sourceMgr;
107 sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
108 transformModule =
110 if (!transformModule) {
111 // Failed to parse the transform module.
112 // Don't need to emit an error here as the parsing should have already done
113 // that.
114 return failure();
115 }
116 return mlir::verify(*transformModule);
117}
118
120 return context->getOrLoadDialect<transform::TransformDialect>()
121 ->getLibraryModule();
122}
123
124static transform::TransformOpInterface
126 for (Region &region : op->getRegions()) {
127 for (Block &block : region.getBlocks()) {
128 for (auto namedSequenceOp : block.getOps<transform::NamedSequenceOp>()) {
129 if (namedSequenceOp.getSymName() == entryPoint) {
130 return cast<transform::TransformOpInterface>(
131 namedSequenceOp.getOperation());
132 }
133 }
134 }
135 }
136 return nullptr;
137}
138
139static transform::TransformOpInterface
140findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
141 transform::TransformOpInterface transform = nullptr;
143 [&](transform::NamedSequenceOp namedSequenceOp) {
144 if (namedSequenceOp.getSymName() == entryPoint) {
145 transform = cast<transform::TransformOpInterface>(
146 namedSequenceOp.getOperation());
147 return WalkResult::interrupt();
148 }
149 return WalkResult::advance();
150 });
151 return transform;
152}
153
154// Will look for the transform's entry point favouring NamedSequenceOps
155// ops that exist within the operation without the need for nesting.
156// If no operation exists in the blocks owned by op, then it will recursively
157// walk the op in preorder and find the first NamedSequenceOp that matches
158// the entry point's name.
159//
160// This allows for the following two use cases:
161// 1. op is a module annotated with the transform.with_named_sequence attribute
162// that has an entry point in its block. E.g.,
163//
164// ```mlir
165// module {transform.with_named_sequence} {
166// transform.named_sequence @__transform_main(%arg0 : !transform.any_op) ->
167// () {
168// transform.yield
169// }
170// }
171// ```
172//
173// 2. op is a program which contains a nested module annotated with the
174// transform.with_named_sequence attribute. E.g.,
175//
176// ```mlir
177// module {
178// func.func @foo () {
179// }
180//
181// module {transform.with_named_sequence} {
182// transform.named_sequence @__transform_main(%arg0 : !transform.any_op)
183// -> () {
184// transform.yield
185// }
186// }
187// }
188// ```
189static transform::TransformOpInterface
190findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
191 transform::TransformOpInterface transform =
193 if (!transform)
195 return transform;
196}
197
198transform::TransformOpInterface
200 StringRef entryPoint) {
202 if (module)
203 l.push_back(module);
204 for (Operation *op : l) {
205 TransformOpInterface transform =
206 findTransformEntryPointInOp(op, entryPoint);
207 if (transform)
208 return transform;
209 }
210 auto diag = root->emitError()
211 << "could not find a nested named sequence with name: "
212 << entryPoint;
213 return nullptr;
214}
215
217 MLIRContext *context, ArrayRef<std::string> transformLibraryPaths,
218 OwningOpRef<ModuleOp> &transformModule) {
219 // Assemble list of library files.
220 SmallVector<std::string> libraryFileNames;
221 if (failed(detail::expandPathsToMLIRFiles(transformLibraryPaths, context,
222 libraryFileNames)))
223 return failure();
224
225 // Parse modules from library files.
226 SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
227 for (const std::string &libraryFileName : libraryFileNames) {
228 OwningOpRef<ModuleOp> parsedLibrary;
230 context, libraryFileName, parsedLibrary)))
231 return failure();
232 parsedLibraries.push_back(std::move(parsedLibrary));
233 }
234
235 // Merge parsed libraries into one module.
236 auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
237 OwningOpRef<ModuleOp> mergedParsedLibraries =
238 ModuleOp::create(loc, "__transform");
239 {
240 mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
241 UnitAttr::get(context));
242 // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
243 for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
245 mergedParsedLibraries.get(), std::move(parsedLibrary))))
246 return parsedLibrary->emitError()
247 << "failed to merge symbols into shared library module";
248 }
249 }
250
251 transformModule = std::move(mergedParsedLibraries);
252 return success();
253}
254
256 Operation *payload, Operation *transformRoot, ModuleOp transformModule,
257 const TransformOptions &options) {
259 bindings.push_back(ArrayRef<Operation *>{payload});
260 return applyTransformNamedSequence(bindings,
261 cast<TransformOpInterface>(transformRoot),
262 transformModule, options);
263}
264
266 RaggedArray<MappedValue> bindings, TransformOpInterface transformRoot,
267 ModuleOp transformModule, const TransformOptions &options) {
268 if (bindings.empty()) {
269 return transformRoot.emitError()
270 << "expected at least one binding for the root";
271 }
272 if (bindings.at(0).size() != 1) {
273 return transformRoot.emitError()
274 << "expected one payload to be bound to the first argument, got "
275 << bindings.at(0).size();
276 }
277 auto *payloadRoot = dyn_cast<Operation *>(bindings.at(0).front());
278 if (!payloadRoot) {
279 return transformRoot->emitError() << "expected the object bound to the "
280 "first argument to be an operation";
281 }
282
283 bindings.removeFront();
284
285 // `transformModule` may not be modified.
286 if (transformModule && !transformModule->isAncestor(transformRoot)) {
287 OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
288 if (failed(detail::mergeSymbolsInto(
290 std::move(clonedTransformModule)))) {
291 return payloadRoot->emitError() << "failed to merge symbols";
292 }
293 }
294
295 LLVM_DEBUG(DBGS() << "Apply\n" << *transformRoot << "\n");
296 LLVM_DEBUG(DBGS() << "To\n" << *payloadRoot << "\n");
297
298 return applyTransforms(payloadRoot, transformRoot, bindings, options,
299 /*enforceToplevelTransformOp=*/false);
300}
return success()
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static transform::TransformOpInterface findTransformEntryPointRecursive(Operation *op, StringRef entryPoint)
static transform::TransformOpInterface findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint)
#define DBGS()
static transform::TransformOpInterface findTransformEntryPointInOp(Operation *op, StringRef entryPoint)
Block represents an ordered list of Operations.
Definition Block.h:33
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
OpTy get() const
Allow accessing the internal op.
Definition OwningOpRef.h:51
A 2D array where each row may have different length.
Definition RaggedArray.h:18
void removeFront()
Removes the first subarray in-place. Invalidates iterators to all rows.
bool empty() const
Returns true if there are no rows in the 2D array.
Definition RaggedArray.h:25
ArrayRef< T > at(size_t pos) const
Definition RaggedArray.h:29
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Options controlling the application of transform operations by the TransformState.
LogicalResult assembleTransformLibraryFromPaths(MLIRContext *context, ArrayRef< std::string > transformLibraryPaths, OwningOpRef< ModuleOp > &transformModule)
Utility to parse, verify, aggregate and link the content of all mlir files nested under transformLibr...
LogicalResult parseTransformModuleFromFile(MLIRContext *context, llvm::StringRef transformFileName, OwningOpRef< ModuleOp > &transformModule)
Utility to parse and verify the content of a transformFileName MLIR file containing a transform diale...
ModuleOp getPreloadedTransformModule(MLIRContext *context)
Utility to load a transform interpreter module from a module that has already been preloaded in the c...
InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
Definition Utils.cpp:80
LogicalResult expandPathsToMLIRFiles(ArrayRef< std::string > paths, MLIRContext *context, SmallVectorImpl< std::string > &fileNames)
Expands the given list of paths to a list of .mlir files.
TransformOpInterface findTransformEntryPoint(Operation *root, ModuleOp module, StringRef entryPoint=TransformDialect::kTransformEntryPointSymbolName)
Finds the first TransformOpInterface named kTransformEntryPointSymbolName that is either:
LogicalResult applyTransformNamedSequence(Operation *payload, Operation *transformRoot, ModuleOp transformModule, const TransformOptions &options)
Standalone util to apply the named sequence transformRoot to payload IR.
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc=nullptr)
This parses the file specified by the indicated SourceMgr and appends parsed operations to the given ...
Definition Parser.cpp:38
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423