23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FileSystem.h"
26 #include "llvm/Support/SourceMgr.h"
27 #include "llvm/Support/raw_ostream.h"
31 #define DEBUG_TYPE "transform-dialect-interpreter-utils"
32 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
42 for (
const std::string &path : paths) {
45 if (llvm::sys::fs::is_regular_file(path)) {
46 LLVM_DEBUG(
DBGS() <<
"Adding '" << path <<
"' to list of files\n");
47 fileNames.push_back(path);
51 if (!llvm::sys::fs::is_directory(path)) {
53 <<
"'" << path <<
"' is neither a file nor a directory";
56 LLVM_DEBUG(
DBGS() <<
"Looking for files in '" << path <<
"':\n");
59 for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
60 it != itEnd && !ec; it.increment(ec)) {
61 const std::string &fileName = it->path();
63 if (it->type() != llvm::sys::fs::file_type::regular_file &&
64 it->type() != llvm::sys::fs::file_type::symlink_file) {
65 LLVM_DEBUG(
DBGS() <<
" Skipping non-regular file '" << fileName
70 if (!StringRef(fileName).ends_with(
".mlir")) {
71 LLVM_DEBUG(
DBGS() <<
" Skipping '" << fileName
72 <<
"' because it does not end with '.mlir'\n");
76 LLVM_DEBUG(
DBGS() <<
" Adding '" << fileName <<
"' to list of files\n");
77 fileNames.push_back(fileName);
81 return emitError(loc) <<
"error while opening files in '" << path
82 <<
"': " << ec.message();
89 MLIRContext *context, llvm::StringRef transformFileName,
91 if (transformFileName.empty()) {
93 DBGS() <<
"no transform file name specified, assuming the transform "
94 "module is embedded in the IR next to the top-level\n");
98 std::string errorMessage;
103 <<
"failed to open transform file: " << errorMessage;
106 llvm::SourceMgr sourceMgr;
107 sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
110 if (!transformModule) {
121 ->getLibraryModule();
124 transform::TransformOpInterface
126 StringRef entryPoint) {
131 transform::TransformOpInterface transform =
nullptr;
133 [&](transform::NamedSequenceOp namedSequenceOp) {
134 if (namedSequenceOp.getSymName() == entryPoint) {
135 transform = cast<transform::TransformOpInterface>(
136 namedSequenceOp.getOperation());
145 <<
"could not find a nested named sequence with name: "
161 for (
const std::string &libraryFileName : libraryFileNames) {
164 context, libraryFileName, parsedLibrary)))
166 parsedLibraries.push_back(std::move(parsedLibrary));
172 ModuleOp::create(loc,
"__transform");
174 mergedParsedLibraries.
get()->setAttr(
"transform.with_named_sequence",
179 mergedParsedLibraries.
get(), std::move(parsedLibrary))))
180 return parsedLibrary->emitError()
181 <<
"failed to merge symbols into shared library module";
185 transformModule = std::move(mergedParsedLibraries);
195 cast<TransformOpInterface>(transformRoot),
202 if (bindings.
empty()) {
203 return transformRoot.emitError()
204 <<
"expected at least one binding for the root";
206 if (bindings.
at(0).size() != 1) {
207 return transformRoot.emitError()
208 <<
"expected one payload to be bound to the first argument, got "
209 << bindings.
at(0).size();
211 auto *payloadRoot = dyn_cast<Operation *>(bindings.
at(0).front());
213 return transformRoot->emitError() <<
"expected the object bound to the "
214 "first argument to be an operation";
220 if (transformModule && !transformModule->isAncestor(transformRoot)) {
224 std::move(clonedTransformModule)))) {
225 return payloadRoot->emitError() <<
"failed to merge symbols";
229 LLVM_DEBUG(
DBGS() <<
"Apply\n" << *transformRoot <<
"\n");
230 LLVM_DEBUG(
DBGS() <<
"To\n" << *payloadRoot <<
"\n");
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy get() const
Allow accessing the internal op.
A 2D array where each row may have a different length.
ArrayRef< T > at(size_t pos) const
void removeFront()
Removes the first subarray in-place. Invalidates iterators to all rows.
bool empty() const
Returns true if there are no rows in the 2D array.
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table enclosing the given operation `from`.
static WalkResult advance()
static WalkResult interrupt()
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...