MLIR  21.0.0git
TranslateRegistration.cpp
Go to the documentation of this file.
1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation from SPIR-V binary module to MLIR SPIR-V
10 // ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/Parser/Parser.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SMLoc.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include "llvm/Support/ToolOutputFile.h"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Deserialization registration
35 //===----------------------------------------------------------------------===//
36 
37 // Deserializes the SPIR-V binary module held in the memory buffer `input`
38 // into `context` and returns the resulting op (a null reference on failure).
// Full signature (source lines 39 and 41 are not shown in this listing):
//   static OwningOpRef<Operation *>
//   deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context,
//                     const spirv::DeserializationOptions &options)
// `options` is forwarded unchanged to spirv::deserialize.
40 deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context,
42  context->loadDialect<spirv::SPIRVDialect>();
43 
44  // Make sure the input stream can be treated as a stream of SPIR-V words:
// a byte size that is not a multiple of 4 is rejected up front with an error
// at an unknown location, and a null reference is returned.
45  auto *start = input->getBufferStart();
46  auto size = input->getBufferSize();
47  if (size % sizeof(uint32_t) != 0) {
48  emitError(UnknownLoc::get(context))
49  << "SPIR-V binary module must contain integral number of 32-bit words";
50  return {};
51  }
52 
// View the buffer in place as 32-bit words (no copy is made).
// NOTE(review): this assumes the buffer start is suitably aligned for
// uint32_t and that words are in host byte order; confirm whether
// spirv::deserialize handles byte-swapped input.
53  auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start),
54  size / sizeof(uint32_t));
55  return spirv::deserialize(binary, context, options);
56 }
57 
58 namespace mlir {
// Registers the `deserialize-spirv` translation (SPIR-V binary -> MLIR).
// NOTE: the enclosing function header (source line 59) is not shown in this
// listing; per the symbol index it is presumably
// `void registerFromSPIRVTranslation()` — confirm against the source.
//
// The option is a function-local static: it is constructed once, on the
// first call, and outlives this scope. The translation lambda below can
// therefore read it without a capture when the translation is invoked later.
60  static llvm::cl::opt<bool> enableControlFlowStructurization(
61  "spirv-structurize-control-flow",
62  llvm::cl::desc(
63  "Enable control flow structurization into `spirv.mlir.selection` and "
64  "`spirv.mlir.loop`. This may need to be disabled to support "
65  "deserialization of early exits (see #138688)"),
66  llvm::cl::init(true));
67 
68  TranslateToMLIRRegistration fromBinary(
69  "deserialize-spirv", "deserializes the SPIR-V module",
70  [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
// Only the main buffer is deserialized. Receiving more than one buffer is a
// caller error, and that is checked only in assert-enabled builds.
71  assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
// The braced argument presumably initializes
// spirv::DeserializationOptions with the flag as its first member — confirm
// the field order in the Deserialization header.
72  return deserializeModule(
73  sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context,
74  {enableControlFlowStructurization});
75  });
76 }
77 } // namespace mlir
78 
79 //===----------------------------------------------------------------------===//
80 // Serialization registration
81 //===----------------------------------------------------------------------===//
82 
// Serializes `module` to a SPIR-V binary using default serialization options
// and writes the resulting words to `output` as raw bytes. If serialization
// fails, it returns failure without writing anything.
// NOTE: source line 85 is not shown in this listing; it presumably declares
// the word buffer `binary` (a uint32_t vector passed to spirv::serialize).
83 static LogicalResult serializeModule(spirv::ModuleOp module,
84  raw_ostream &output) {
86  if (failed(spirv::serialize(module, binary)))
87  return failure();
88 
// Emit the words' bytes verbatim in host memory order; no byte swapping is
// performed here.
89  output.write(reinterpret_cast<char *>(binary.data()),
90  binary.size() * sizeof(uint32_t));
91 
92  return mlir::success();
93 }
94 
95 namespace mlir {
// Registers the `serialize-spirv` translation (SPIR-V dialect -> binary),
// which delegates to serializeModule.
// NOTE: source lines 96-97 are not shown in this listing. They presumably
// open `void registerToSPIRVTranslation()` and declare the
// TranslateFromMLIRRegistration object whose constructor arguments follow.
// The second lambda adds the SPIR-V dialect to the registry used for this
// translation.
98  "serialize-spirv", "serialize SPIR-V dialect",
99  [](spirv::ModuleOp module, raw_ostream &output) {
100  return serializeModule(module, output);
101  },
102  [](DialectRegistry &registry) {
103  registry.insert<spirv::SPIRVDialect>();
104  });
105 }
106 } // namespace mlir
107 
108 //===----------------------------------------------------------------------===//
109 // Round-trip registration
110 //===----------------------------------------------------------------------===//
111 
// Round-trip test helper. It:
//   1. serializes `module` to a SPIR-V binary, including debug info when
//      `emitDebugInfo` is set;
//   2. deserializes that binary into a separate MLIRContext;
//   3. prints the reconstructed module to `output`.
// Returns failure if serialization or deserialization fails.
// NOTE: source lines 114 and 117 are not shown in this listing. They
// presumably declare the word buffer `binary` and the
// spirv::SerializationOptions `options` used below.
112 static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo,
113  raw_ostream &output) {
115  MLIRContext *context = module->getContext();
116 
118  options.emitDebugInfo = emitDebugInfo;
119  if (failed(spirv::serialize(module, binary, options)))
120  return failure();
121 
// Deserialize into a new context built from the same dialect registry, so
// the printed module is rebuilt from the binary alone. `spirvModule` is
// declared after this context and is therefore destroyed before it.
122  MLIRContext deserializationContext(context->getDialectRegistry());
123  // TODO: we should only load the required dialects instead of all dialects.
124  deserializationContext.loadAllAvailableDialects();
125  // Then deserialize to get back a SPIR-V module.
// No options are passed, so default DeserializationOptions are used. The
// `spirv-structurize-control-flow` command-line flag does not apply here.
126  OwningOpRef<spirv::ModuleOp> spirvModule =
127  spirv::deserialize(binary, &deserializationContext);
128  if (!spirvModule)
129  return failure();
130  spirvModule->print(output);
131 
132  return mlir::success();
133 }
134 
135 namespace mlir {
// Registers two test translations that round-trip a SPIR-V module through
// roundTripModule:
//   - `test-spirv-roundtrip`: without debug info;
//   - `test-spirv-roundtrip-debug`: with debug info.
// In each registration, the second lambda adds the SPIR-V dialect to the
// registry for that translation.
// NOTE: source lines 136-137 and 147-148 are not shown in this listing. Per
// the symbol index they presumably open `void registerTestRoundtripSPIRV()`
// and `void registerTestRoundtripDebugSPIRV()` and declare each
// TranslateFromMLIRRegistration object.
138  "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect",
139  [](spirv::ModuleOp module, raw_ostream &output) {
140  return roundTripModule(module, /*emitDebugInfo=*/false, output);
141  },
142  [](DialectRegistry &registry) {
143  registry.insert<spirv::SPIRVDialect>();
144  });
145 }
146 
149  "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V",
150  [](spirv::ModuleOp module, raw_ostream &output) {
151  return roundTripModule(module, /*emitDebugInfo=*/true, output);
152  },
153  [](DialectRegistry &registry) {
154  registry.insert<spirv::SPIRVDialect>();
155  });
156 }
157 } // namespace mlir
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult serializeModule(spirv::ModuleOp module, raw_ostream &output)
static OwningOpRef< Operation * > deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context, const spirv::DeserializationOptions &options)
static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo, raw_ostream &output)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:107
void loadAllAvailableDialects()
Load all dialects available in the registry in this context.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition: OwningOpRef.h:29
OwningOpRef< spirv::ModuleOp > deserialize(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options={})
Deserializes the given SPIR-V binary module and creates a MLIR ModuleOp in the given context.
LogicalResult serialize(ModuleOp module, SmallVectorImpl< uint32_t > &binary, const SerializationOptions &options={})
Serializes the given SPIR-V module and writes to binary.
Include the generated interface declarations.
void registerTestRoundtripSPIRV()
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerFromSPIRVTranslation()
void registerTestRoundtripDebugSPIRV()
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerToSPIRVTranslation()
Use Translate[ToMLIR|FromMLIR]Registration as an initializer that registers a function and associates...
Definition: Translation.h:110