24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/FileSystem.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include "llvm/Support/raw_ostream.h"
33 #define DEBUG_TYPE "transform-dialect-interpreter-utils"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
44 for (
const std::string &path : paths) {
47 if (llvm::sys::fs::is_regular_file(path)) {
48 LLVM_DEBUG(
DBGS() <<
"Adding '" << path <<
"' to list of files\n");
49 fileNames.push_back(path);
53 if (!llvm::sys::fs::is_directory(path)) {
55 <<
"'" << path <<
"' is neither a file nor a directory";
58 LLVM_DEBUG(
DBGS() <<
"Looking for files in '" << path <<
"':\n");
61 for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
62 it != itEnd && !ec; it.increment(ec)) {
63 const std::string &fileName = it->path();
65 if (it->type() != llvm::sys::fs::file_type::regular_file &&
66 it->type() != llvm::sys::fs::file_type::symlink_file) {
67 LLVM_DEBUG(
DBGS() <<
" Skipping non-regular file '" << fileName
72 if (!StringRef(fileName).ends_with(
".mlir")) {
73 LLVM_DEBUG(
DBGS() <<
" Skipping '" << fileName
74 <<
"' because it does not end with '.mlir'\n");
78 LLVM_DEBUG(
DBGS() <<
" Adding '" << fileName <<
"' to list of files\n");
79 fileNames.push_back(fileName);
83 return emitError(loc) <<
"error while opening files in '" << path
84 <<
"': " << ec.message();
91 MLIRContext *context, llvm::StringRef transformFileName,
93 if (transformFileName.empty()) {
95 DBGS() <<
"no transform file name specified, assuming the transform "
96 "module is embedded in the IR next to the top-level\n");
100 std::string errorMessage;
105 <<
"failed to open transform file: " << errorMessage;
108 llvm::SourceMgr sourceMgr;
109 sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
112 if (!transformModule) {
123 ->getLibraryModule();
126 transform::TransformOpInterface
128 StringRef entryPoint) {
133 transform::TransformOpInterface transform =
nullptr;
135 [&](transform::NamedSequenceOp namedSequenceOp) {
136 if (namedSequenceOp.getSymName() == entryPoint) {
137 transform = cast<transform::TransformOpInterface>(
138 namedSequenceOp.getOperation());
147 <<
"could not find a nested named sequence with name: "
163 for (
const std::string &libraryFileName : libraryFileNames) {
166 context, libraryFileName, parsedLibrary)))
168 parsedLibraries.push_back(std::move(parsedLibrary));
174 ModuleOp::create(loc,
"__transform");
176 mergedParsedLibraries.
get()->setAttr(
"transform.with_named_sequence",
181 mergedParsedLibraries.
get(), std::move(parsedLibrary))))
182 return parsedLibrary->emitError()
183 <<
"failed to merge symbols into shared library module";
187 transformModule = std::move(mergedParsedLibraries);
197 cast<TransformOpInterface>(transformRoot),
204 if (bindings.
empty()) {
205 return transformRoot.emitError()
206 <<
"expected at least one binding for the root";
208 if (bindings.
at(0).size() != 1) {
209 return transformRoot.emitError()
210 <<
"expected one payload to be bound to the first argument, got "
211 << bindings.
at(0).size();
213 auto *payloadRoot = bindings.
at(0).front().dyn_cast<
Operation *>();
215 return transformRoot->
emitError() <<
"expected the object bound to the "
216 "first argument to be an operation";
222 if (transformModule && !transformModule->isAncestor(transformRoot)) {
226 std::move(clonedTransformModule)))) {
227 return payloadRoot->emitError() <<
"failed to merge symbols";
231 LLVM_DEBUG(
DBGS() <<
"Apply\n" << *transformRoot <<
"\n");
232 LLVM_DEBUG(
DBGS() <<
"To\n" << *payloadRoot <<
"\n");
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OpTy get() const
Allow accessing the internal op.
A 2D array where each row may have different length.
ArrayRef< T > at(size_t pos) const
void removeFront()
Removes the first subarray in-place. Invalidates iterators to all rows.
bool empty() const
Returns true if there are no rows in the 2D array.
void push_back(Range &&elements)
Appends the given range of elements as a new row to the 2D array.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation `from`.
static WalkResult advance()
static WalkResult interrupt()
Include the generated interface declarations.
std::unique_ptr< llvm::MemoryBuffer > openInputFile(llvm::StringRef inputFilename, std::string *errorMessage=nullptr)
Open the file specified by its name for reading.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...