MLIR  22.0.0git
TranslateRegistration.cpp
Go to the documentation of this file.
1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file registers the mlir-translate hooks that convert between SPIR-V
// binary modules and the MLIR SPIR-V dialect: deserialization, serialization,
// and round-trip testing.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Parser/Parser.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/FileSystem.h"
25 #include "llvm/Support/MemoryBuffer.h"
26 #include "llvm/Support/Path.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include "llvm/Support/ToolOutputFile.h"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Deserialization registration
34 //===----------------------------------------------------------------------===//
35 
36 // Deserializes the SPIR-V binary module stored in the file named as
37 // `inputFilename` and returns a module containing the SPIR-V module.
39 deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context,
41  context->loadDialect<spirv::SPIRVDialect>();
42 
43  // Make sure the input stream can be treated as a stream of SPIR-V words
44  auto *start = input->getBufferStart();
45  auto size = input->getBufferSize();
46  if (size % sizeof(uint32_t) != 0) {
47  emitError(UnknownLoc::get(context))
48  << "SPIR-V binary module must contain integral number of 32-bit words";
49  return {};
50  }
51 
52  auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start),
53  size / sizeof(uint32_t));
54  return spirv::deserialize(binary, context, options);
55 }
56 
57 namespace mlir {
59  static llvm::cl::opt<bool> enableControlFlowStructurization(
60  "spirv-structurize-control-flow",
61  llvm::cl::desc(
62  "Enable control flow structurization into `spirv.mlir.selection` and "
63  "`spirv.mlir.loop`. This may need to be disabled to support "
64  "deserialization of early exits (see #138688)"),
65  llvm::cl::init(true));
66 
67  TranslateToMLIRRegistration fromBinary(
68  "deserialize-spirv", "deserializes the SPIR-V module",
69  [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
70  assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
71  return deserializeModule(
72  sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context,
73  {enableControlFlowStructurization});
74  });
75 }
76 } // namespace mlir
77 
78 //===----------------------------------------------------------------------===//
79 // Serialization registration
80 //===----------------------------------------------------------------------===//
81 
82 static LogicalResult
83 serializeModule(spirv::ModuleOp moduleOp, raw_ostream &output,
86  if (failed(spirv::serialize(moduleOp, binary)))
87  return failure();
88 
89  size_t sizeInBytes = binary.size() * sizeof(uint32_t);
90 
91  output.write(reinterpret_cast<char *>(binary.data()), sizeInBytes);
92 
93  if (options.saveModuleForValidation) {
94  size_t dirSeparator =
95  options.validationFilePrefix.find(llvm::sys::path::get_separator());
96  // If file prefix includes directory check if that directory exists.
97  if (dirSeparator != std::string::npos) {
98  llvm::StringRef parentDir =
99  llvm::sys::path::parent_path(options.validationFilePrefix);
100  if (!llvm::sys::fs::is_directory(parentDir))
101  return moduleOp.emitError(
102  "validation prefix directory does not exist\n");
103  }
104 
105  SmallString<128> filename;
106  int fd = 0;
107 
108  std::error_code errorCode = llvm::sys::fs::createUniqueFile(
109  options.validationFilePrefix + "%%%%%%.spv", fd, filename);
110  if (errorCode)
111  return moduleOp.emitError("error creating validation output file: ")
112  << errorCode.message() << "\n";
113 
114  llvm::raw_fd_ostream validationOutput(fd, /*shouldClose=*/true);
115  validationOutput.write(reinterpret_cast<char *>(binary.data()),
116  sizeInBytes);
117  validationOutput.flush();
118  }
119 
120  return mlir::success();
121 }
122 
123 namespace mlir {
125  static llvm::cl::opt<std::string> validationFilesPrefix(
126  "spirv-save-validation-files-with-prefix",
127  llvm::cl::desc(
128  "When non-empty string is passed each serialized SPIR-V module is "
129  "saved to an additional file that starts with the given prefix. This "
130  "is used to generate separate binaries for validation, where "
131  "`--split-input-file` normally combines all outputs into one. The "
132  "one combined output (`-o`) is still written. Created files need to "
133  "be removed manually once processed."),
134  llvm::cl::init(""));
135 
137  "serialize-spirv", "serialize SPIR-V dialect",
138  [](spirv::ModuleOp moduleOp, raw_ostream &output) {
139  return serializeModule(moduleOp, output,
140  {true, false, !validationFilesPrefix.empty(),
141  validationFilesPrefix});
142  },
143  [](DialectRegistry &registry) {
144  registry.insert<spirv::SPIRVDialect>();
145  });
146 }
147 } // namespace mlir
148 
149 //===----------------------------------------------------------------------===//
150 // Round-trip registration
151 //===----------------------------------------------------------------------===//
152 
153 static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo,
154  raw_ostream &output) {
156  MLIRContext *context = module->getContext();
157 
159  options.emitDebugInfo = emitDebugInfo;
160  if (failed(spirv::serialize(module, binary, options)))
161  return failure();
162 
163  MLIRContext deserializationContext(context->getDialectRegistry());
164  // TODO: we should only load the required dialects instead of all dialects.
165  deserializationContext.loadAllAvailableDialects();
166  // Then deserialize to get back a SPIR-V module.
167  OwningOpRef<spirv::ModuleOp> spirvModule =
168  spirv::deserialize(binary, &deserializationContext);
169  if (!spirvModule)
170  return failure();
171  spirvModule->print(output);
172 
173  return mlir::success();
174 }
175 
176 namespace mlir {
179  "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect",
180  [](spirv::ModuleOp module, raw_ostream &output) {
181  return roundTripModule(module, /*emitDebugInfo=*/false, output);
182  },
183  [](DialectRegistry &registry) {
184  registry.insert<spirv::SPIRVDialect>();
185  });
186 }
187 
190  "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V",
191  [](spirv::ModuleOp module, raw_ostream &output) {
192  return roundTripModule(module, /*emitDebugInfo=*/true, output);
193  },
194  [](DialectRegistry &registry) {
195  registry.insert<spirv::SPIRVDialect>();
196  });
197 }
198 } // namespace mlir
static llvm::ManagedStatic< PassManagerOptions > options
static OwningOpRef< Operation * > deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context, const spirv::DeserializationOptions &options)
static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo, raw_ostream &output)
static LogicalResult serializeModule(spirv::ModuleOp moduleOp, raw_ostream &output, const spirv::SerializationOptions &options)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:110
void loadAllAvailableDialects()
Load all dialects available in the registry in this context.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
LogicalResult serialize(ModuleOp moduleOp, SmallVectorImpl< uint32_t > &binary, const SerializationOptions &options={})
Serializes the given SPIR-V moduleOp and writes to binary.
OwningOpRef< spirv::ModuleOp > deserialize(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options={})
Deserializes the given SPIR-V binary module and creates a MLIR ModuleOp in the given context.
Include the generated interface declarations.
void registerTestRoundtripSPIRV()
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerFromSPIRVTranslation()
void registerTestRoundtripDebugSPIRV()
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerToSPIRVTranslation()
Use Translate[ToMLIR|FromMLIR]Registration as an initializer that registers a function and associates...
Definition: Translation.h:110