MLIR  22.0.0git
TranslateFromWasm.cpp
Go to the documentation of this file.
1 //===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the WebAssembly importer.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Location.h"
19 #include "mlir/Support/LLVM.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/Endian.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/LEB128.h"
27 #include "llvm/Support/LogicalResult.h"
28 
29 #include <cstddef>
30 #include <cstdint>
31 #include <variant>
32 
33 #define DEBUG_TYPE "wasm-translate"
34 
35 static_assert(CHAR_BIT == 8,
36  "This code expects std::byte to be exactly 8 bits");
37 
38 using namespace mlir;
39 using namespace mlir::wasm;
40 using namespace mlir::wasmssa;
41 
42 namespace {
/// Integer type used to store a section identifier.
using section_id_t = uint8_t;

/// Kinds of sections handled by the importer. Each enumerator carries the
/// numeric identifier it is compared against when a section id is parsed.
enum class WasmSectionType : section_id_t {
  CUSTOM = 0,
  TYPE = 1,
  IMPORT = 2,
  FUNCTION = 3,
  TABLE = 4,
  MEMORY = 5,
  GLOBAL = 6,
  EXPORT = 7,
  START = 8,
  ELEMENT = 9,
  CODE = 10,
  DATA = 11,
  DATACOUNT = 12
};

/// Largest valid section identifier (the value of DATACOUNT).
constexpr section_id_t highestWasmSectionID =
    static_cast<section_id_t>(WasmSectionType::DATACOUNT);
62 
63 #define APPLY_WASM_SEC_TRANSFORM \
64  WASM_SEC_TRANSFORM(CUSTOM) \
65  WASM_SEC_TRANSFORM(TYPE) \
66  WASM_SEC_TRANSFORM(IMPORT) \
67  WASM_SEC_TRANSFORM(FUNCTION) \
68  WASM_SEC_TRANSFORM(TABLE) \
69  WASM_SEC_TRANSFORM(MEMORY) \
70  WASM_SEC_TRANSFORM(GLOBAL) \
71  WASM_SEC_TRANSFORM(EXPORT) \
72  WASM_SEC_TRANSFORM(START) \
73  WASM_SEC_TRANSFORM(ELEMENT) \
74  WASM_SEC_TRANSFORM(CODE) \
75  WASM_SEC_TRANSFORM(DATA) \
76  WASM_SEC_TRANSFORM(DATACOUNT)
77 
78 template <WasmSectionType>
79 constexpr const char *wasmSectionName = "";
80 
81 #define WASM_SEC_TRANSFORM(section) \
82  template <> \
83  [[maybe_unused]] constexpr const char \
84  *wasmSectionName<WasmSectionType::section> = #section;
86 #undef WASM_SEC_TRANSFORM
87 
88 constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
89  return secType != WasmSectionType::CUSTOM;
90 }
91 
/// Empty tag type that carries a list of bytes in its template arguments.
template <std::byte... Values>
struct ByteSequence {};

/// A ByteSequence made of exactly one byte.
template <std::byte Value>
struct UniqueByte : ByteSequence<Value> {};
98 
/// Tag object listing byte encodings as a ByteSequence.
// NOTE(review): the embedded listing numbering jumps from 99 to 102 here, so
// the leading template arguments of this sequence are not visible — restore
// them from the original file before building.
[[maybe_unused]] constexpr ByteSequence<
    WasmBinaryEncoding::Type::v128> valueTypesEncodings{};
103 
104 template <std::byte... allowedFlags>
105 constexpr bool isValueOneOf(std::byte value,
106  ByteSequence<allowedFlags...> = {}) {
107  return ((value == allowedFlags) | ... | false);
108 }
109 
110 template <std::byte... flags>
111 constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
112  return !isValueOneOf<flags...>(value);
113 }
114 
115 struct GlobalTypeRecord {
116  Type type;
117  bool isMutable;
118 };
119 
120 struct TypeIdxRecord {
121  size_t id;
122 };
123 
124 struct SymbolRefContainer {
125  FlatSymbolRefAttr symbol;
126 };
127 
128 struct GlobalSymbolRefContainer : SymbolRefContainer {
129  Type globalType;
130 };
131 
132 struct FunctionSymbolRefContainer : SymbolRefContainer {
133  FunctionType functionType;
134 };
135 
136 using ImportDesc =
137  std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
138 
139 using parsed_inst_t = FailureOr<SmallVector<Value>>;
140 
/// Symbol bookkeeping for a module, plus helpers that build fresh symbol names
/// of the form `<prefix><index>`.
// NOTE(review): funcSymbols, globalSymbols and memSymbols are used below but
// their declarations are not visible in this listing (the embedded numbering
// skips 142-144) — restore them from the original file.
struct WasmModuleSymbolTables {
  SmallVector<SymbolRefContainer> tableSymbols;
  SmallVector<FunctionType> moduleFuncTypes;

  /// Concatenates \p prefix and \p id into a symbol name.
  std::string getNewSymbolName(StringRef prefix, size_t id) const {
    return (prefix + Twine{id}).str();
  }

  /// Name for the next function symbol: "func_" followed by the current
  /// number of function symbols.
  std::string getNewFuncSymbolName() const {
    size_t id = funcSymbols.size();
    return getNewSymbolName("func_", id);
  }

  /// Name for the next global symbol: "global_" followed by the current
  /// number of global symbols.
  std::string getNewGlobalSymbolName() const {
    size_t id = globalSymbols.size();
    return getNewSymbolName("global_", id);
  }

  /// Name for the next memory symbol: "mem_" followed by the current number
  /// of memory symbols.
  std::string getNewMemorySymbolName() const {
    size_t id = memSymbols.size();
    return getNewSymbolName("mem_", id);
  }

  /// Name for the next table symbol: "table_" followed by the current number
  /// of table symbols.
  std::string getNewTableSymbolName() const {
    size_t id = tableSymbols.size();
    return getNewSymbolName("table_", id);
  }
};
172 
173 class ParserHead;
174 
175 /// Wrapper around SmallVector to only allow access as push and pop on the
176 /// stack. Makes sure that there are no "free accesses" on the stack to preserve
177 /// its state.
178 class ValueStack {
179 private:
180  struct LabelLevel {
181  size_t stackIdx;
182  LabelLevelOpInterface levelOp;
183  };
184 
185 public:
186  bool empty() const { return values.empty(); }
187 
188  size_t size() const { return values.size(); }
189 
190  /// Pops values from the stack because they are being used in an operation.
191  /// @param operandTypes The list of expected types of the operation, used
192  /// to know how many values to pop and check if the types match the
193  /// expectation.
194  /// @param opLoc Location of the caller, used to report accurately the
195  /// location
196  /// if an error occurs.
197  /// @return Failure or the vector of popped values.
198  FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
199  Location *opLoc);
200 
201  /// Push the results of an operation to the stack so they can be used in a
202  /// following operation.
203  /// @param results The list of results of the operation
204  /// @param opLoc Location of the caller, used to report accurately the
205  /// location
206  /// if an error occurs.
207  LogicalResult pushResults(ValueRange results, Location *opLoc);
208 
209 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
210  /// A simple dump function for debugging.
211  /// Writes output to llvm::dbgs().
212  LLVM_DUMP_METHOD void dump() const;
213 #endif
214 
215 private:
216  SmallVector<Value> values;
217 };
218 
219 using local_val_t = TypedValue<wasmssa::LocalRefType>;
220 
221 class ExpressionParser {
222 public:
223  using locals_t = SmallVector<local_val_t>;
224  ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
225  ArrayRef<local_val_t> initLocal)
226  : parser{parser}, symbols{symbols}, locals{initLocal} {}
227 
228 private:
229  template <std::byte opCode>
230  inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
231 
232  template <typename valueT>
233  parsed_inst_t
234  parseConstInst(OpBuilder &builder,
235  std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
236 
237  /// Construct an operation with \p numOperands operands and a single result.
238  /// Each operand must have the same type. Suitable for e.g. binops, unary
239  /// ops, etc.
240  ///
241  /// \p opcode - The WASM opcode to build.
242  /// \p valueType - The operand and result type for the built instruction.
243  /// \p numOperands - The number of operands for the built operation.
244  ///
245  /// \returns The parsed instruction result, or failure.
246  template <typename opcode, typename valueType, unsigned int numOperands>
247  inline parsed_inst_t
248  buildNumericOp(OpBuilder &builder,
249  std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
250 
251  /// This function generates a dispatch tree to associate an opcode with a
252  /// parser. Parsers are registered by specialising the
253  /// `parseSpecificInstruction` function for the op code to handle.
254  ///
255  /// The dispatcher is generated by recursively creating all possible patterns
256  /// for an opcode and calling the relevant parser on the leaf.
257  ///
258  /// @tparam patternBitSize is the first bit for which the pattern is not fixed
259  ///
260  /// @tparam highBitPattern is the fixed pattern that this instance handles for
261  /// the 8-patternBitSize bits
262  template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
263  inline parsed_inst_t dispatchToInstParser(std::byte opCode,
264  OpBuilder &builder) {
265  static_assert(patternBitSize <= 8,
266  "PatternBitSize is outside of range of opcode space! "
267  "(expected at most 8 bits)");
268  if constexpr (patternBitSize < 8) {
269  constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
270  constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
271  constexpr size_t nextPatternBitSize = patternBitSize + 1;
272  if ((opCode & bitSelect) != std::byte{0})
273  return dispatchToInstParser<nextPatternBitSize,
274  nextHighBitPatternStem | std::byte{1}>(
275  opCode, builder);
276  return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
277  opCode, builder);
278  } else {
279  return parseSpecificInstruction<highBitPattern>(builder);
280  }
281  }
282 
283  struct ParseResultWithInfo {
284  SmallVector<Value> opResults;
285  std::byte endingByte;
286  };
287 
288 public:
289  template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
290  parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
291 
292  template <std::byte... ExpressionParseEnd>
293  FailureOr<ParseResultWithInfo>
294  parse(OpBuilder &builder,
295  ByteSequence<ExpressionParseEnd...> parsingEndFilters);
296 
297  FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
298  return valueStack.popOperands(operandTypes, &currentOpLoc.value());
299  }
300 
301  LogicalResult pushResults(ValueRange results) {
302  return valueStack.pushResults(results, &currentOpLoc.value());
303  }
304 
305  /// The local.set and local.tee operations behave similarly and only differ
306  /// on their return value. This function factorizes the behavior of the two
307  /// operations in one place.
308  template <typename OpToCreate>
309  parsed_inst_t parseSetOrTee(OpBuilder &);
310 
311 private:
312  std::optional<Location> currentOpLoc;
313  ParserHead &parser;
314  WasmModuleSymbolTables const &symbols;
315  locals_t locals;
316  ValueStack valueStack;
317 };
318 
319 class ParserHead {
320 public:
321  ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
322  ParserHead(ParserHead &&) = default;
323 
324 private:
325  ParserHead(ParserHead const &other) = default;
326 
327 public:
328  auto getLocation() const {
329  return FileLineColLoc::get(locName, 0, anchorOffset + offset);
330  }
331 
332  FailureOr<StringRef> consumeNBytes(size_t nBytes) {
333  LDBG() << "Consume " << nBytes << " bytes";
334  LDBG() << " Bytes remaining: " << size();
335  LDBG() << " Current offset: " << offset;
336  if (nBytes > size())
337  return emitError(getLocation(), "trying to extract ")
338  << nBytes << "bytes when only " << size() << "are available";
339 
340  StringRef res = head.slice(offset, offset + nBytes);
341  offset += nBytes;
342  LDBG() << " Updated offset (+" << nBytes << "): " << offset;
343  return res;
344  }
345 
346  FailureOr<std::byte> consumeByte() {
347  FailureOr<StringRef> res = consumeNBytes(1);
348  if (failed(res))
349  return failure();
350  return std::byte{*res->bytes_begin()};
351  }
352 
353  template <typename T>
354  FailureOr<T> parseLiteral();
355 
356  FailureOr<uint32_t> parseVectorSize();
357 
358 private:
359  // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed
360  // if parseLiteral specialization were moved here, but default GCC on Ubuntu
361  // 22.04 has bug with template specialization in class declaration
362  inline FailureOr<uint32_t> parseUI32();
363  inline FailureOr<int64_t> parseI64();
364 
365 public:
366  FailureOr<StringRef> parseName() {
367  FailureOr<uint32_t> size = parseVectorSize();
368  if (failed(size))
369  return failure();
370 
371  return consumeNBytes(*size);
372  }
373 
374  FailureOr<WasmSectionType> parseWasmSectionType() {
375  FailureOr<std::byte> id = consumeByte();
376  if (failed(id))
377  return failure();
378  if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
379  return emitError(getLocation(), "invalid section ID: ")
380  << static_cast<int>(*id);
381  return static_cast<WasmSectionType>(*id);
382  }
383 
384  FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
385  using WasmLimits = WasmBinaryEncoding::LimitHeader;
386  FileLineColLoc limitLocation = getLocation();
387  FailureOr<std::byte> limitHeader = consumeByte();
388  if (failed(limitHeader))
389  return failure();
390 
391  if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
392  return emitError(limitLocation, "invalid limit header: ")
393  << static_cast<int>(*limitHeader);
394  FailureOr<uint32_t> minParse = parseUI32();
395  if (failed(minParse))
396  return failure();
397  std::optional<uint32_t> max{std::nullopt};
398  if (*limitHeader == WasmLimits::bothLimits) {
399  FailureOr<uint32_t> maxParse = parseUI32();
400  if (failed(maxParse))
401  return failure();
402  max = *maxParse;
403  }
404  return LimitType::get(ctx, *minParse, max);
405  }
406 
  /// Consumes one byte and maps it to the corresponding type. An unknown
  /// encoding is reported as "invalid value type encoding" at the location of
  /// the consumed byte.
  // NOTE(review): the `case` labels of this switch are not visible in the
  // listing (the embedded numbering skips the line before every `return`) —
  // restore them from the original file to know which encoding selects each
  // type.
  FailureOr<Type> parseValueType(MLIRContext *ctx) {
    // Captured before consuming so the diagnostic points at the type byte.
    FileLineColLoc typeLoc = getLocation();
    FailureOr<std::byte> typeEncoding = consumeByte();
    if (failed(typeEncoding))
      return failure();
    switch (*typeEncoding) {
      return IntegerType::get(ctx, 32);
      return IntegerType::get(ctx, 64);
      return Float32Type::get(ctx);
      return Float64Type::get(ctx);
      return IntegerType::get(ctx, 128);
      return wasmssa::FuncRefType::get(ctx);
      return wasmssa::ExternRefType::get(ctx);
    default:
      return emitError(typeLoc, "invalid value type encoding: ")
             << static_cast<int>(*typeEncoding);
    }
  }
432 
433  FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
434  using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
435  FailureOr<Type> typeParsed = parseValueType(ctx);
436  if (failed(typeParsed))
437  return failure();
438  FileLineColLoc mutLoc = getLocation();
439  FailureOr<std::byte> mutSpec = consumeByte();
440  if (failed(mutSpec))
441  return failure();
442  if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
443  return emitError(mutLoc, "invalid global mutability specifier: ")
444  << static_cast<int>(*mutSpec);
445  return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
446  }
447 
448  FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
449  FailureOr<uint32_t> nParamsParsed = parseVectorSize();
450  if (failed(nParamsParsed))
451  return failure();
452  uint32_t nParams = *nParamsParsed;
453  SmallVector<Type> res{};
454  res.reserve(nParams);
455  for (size_t i = 0; i < nParams; ++i) {
456  FailureOr<Type> parsedType = parseValueType(ctx);
457  if (failed(parsedType))
458  return failure();
459  res.push_back(*parsedType);
460  }
461  return TupleType::get(ctx, res);
462  }
463 
464  FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
465  FileLineColLoc typeLoc = getLocation();
466  FailureOr<std::byte> funcTypeHeader = consumeByte();
467  if (failed(funcTypeHeader))
468  return failure();
469  if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
470  return emitError(typeLoc, "invalid function type header byte. Expecting ")
471  << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
472  << " got " << std::to_integer<unsigned>(*funcTypeHeader);
473  FailureOr<TupleType> inputTypes = parseResultType(ctx);
474  if (failed(inputTypes))
475  return failure();
476 
477  FailureOr<TupleType> resTypes = parseResultType(ctx);
478  if (failed(resTypes))
479  return failure();
480 
481  return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
482  }
483 
484  FailureOr<TypeIdxRecord> parseTypeIndex() {
485  FailureOr<uint32_t> res = parseUI32();
486  if (failed(res))
487  return failure();
488  return TypeIdxRecord{*res};
489  }
490 
491  FailureOr<TableType> parseTableType(MLIRContext *ctx) {
492  FailureOr<Type> elmTypeParse = parseValueType(ctx);
493  if (failed(elmTypeParse))
494  return failure();
495  if (!isWasmRefType(*elmTypeParse))
496  return emitError(getLocation(), "invalid element type for table");
497  FailureOr<LimitType> limitParse = parseLimit(ctx);
498  if (failed(limitParse))
499  return failure();
500  return TableType::get(ctx, *elmTypeParse, *limitParse);
501  }
502 
  /// Consumes the import kind byte and parses the descriptor it announces,
  /// returned as the matching ImportDesc alternative. An unknown kind is
  /// reported as "invalid import type descriptor" at the kind byte's location.
  // NOTE(review): the `case` labels of this switch are not visible in the
  // listing (the embedded numbering skips the line before every `return`) —
  // restore them from the original file.
  FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
    FileLineColLoc importLoc = getLocation();
    FailureOr<std::byte> importType = consumeByte();
    // Wraps a successful parse result into an ImportDesc and forwards
    // failures unchanged.
    auto packager = [](auto parseResult) -> FailureOr<ImportDesc> {
      if (failed(parseResult))
        return failure();
      return {*parseResult};
    };
    if (failed(importType))
      return failure();
    switch (*importType) {
      return packager(parseTypeIndex());
      return packager(parseTableType(ctx));
      return packager(parseLimit(ctx));
      return packager(parseGlobalType(ctx));
    default:
      return emitError(importLoc, "invalid import type descriptor: ")
             << static_cast<int>(*importType);
    }
  }
527 
528  parsed_inst_t parseExpression(OpBuilder &builder,
529  WasmModuleSymbolTables const &symbols,
530  ArrayRef<local_val_t> locals = {}) {
531  auto eParser = ExpressionParser{*this, symbols, locals};
532  return eParser.parse(builder);
533  }
534 
535  LogicalResult parseCodeFor(FuncOp func,
536  WasmModuleSymbolTables const &symbols) {
537  SmallVector<local_val_t> locals{};
538  // Populating locals with function argument
539  Block &block = func.getBody().front();
540  // Delete temporary return argument which was only created for IR validity
541  assert(func.getBody().getBlocks().size() == 1 &&
542  "Function should only have its default created block at this point");
543  assert(block.getOperations().size() == 1 &&
544  "Only the placeholder return op should be present at this point");
545  auto returnOp = cast<ReturnOp>(&block.back());
546  assert(returnOp);
547 
548  FailureOr<uint32_t> codeSizeInBytes = parseUI32();
549  if (failed(codeSizeInBytes))
550  return failure();
551  FailureOr<StringRef> codeContent = consumeNBytes(*codeSizeInBytes);
552  if (failed(codeContent))
553  return failure();
554  auto name = StringAttr::get(func->getContext(),
555  locName.str() + "::" + func.getSymName());
556  auto cParser = ParserHead{*codeContent, name};
557  FailureOr<uint32_t> localVecSize = cParser.parseVectorSize();
558  if (failed(localVecSize))
559  return failure();
560  OpBuilder builder{&func.getBody().front().back()};
561  for (auto arg : block.getArguments())
562  locals.push_back(cast<TypedValue<LocalRefType>>(arg));
563  // Declare the local ops
564  uint32_t nVarVec = *localVecSize;
565  for (size_t i = 0; i < nVarVec; ++i) {
566  FileLineColLoc varLoc = cParser.getLocation();
567  FailureOr<uint32_t> nSubVar = cParser.parseUI32();
568  if (failed(nSubVar))
569  return failure();
570  FailureOr<Type> varT = cParser.parseValueType(func->getContext());
571  if (failed(varT))
572  return failure();
573  for (size_t j = 0; j < *nSubVar; ++j) {
574  auto local = builder.create<LocalOp>(varLoc, *varT);
575  locals.push_back(local.getResult());
576  }
577  }
578  parsed_inst_t res = cParser.parseExpression(builder, symbols, locals);
579  if (failed(res))
580  return failure();
581  if (!cParser.end())
582  return emitError(cParser.getLocation(),
583  "unparsed garbage remaining at end of code block");
584  builder.create<ReturnOp>(func->getLoc(), *res);
585  returnOp->erase();
586  return success();
587  }
588 
589  bool end() const { return curHead().empty(); }
590 
591  ParserHead copy() const { return *this; }
592 
593 private:
594  StringRef curHead() const { return head.drop_front(offset); }
595 
596  FailureOr<std::byte> peek() const {
597  if (end())
598  return emitError(
599  getLocation(),
600  "trying to peek at next byte, but input stream is empty");
601  return static_cast<std::byte>(curHead().front());
602  }
603 
604  size_t size() const { return head.size() - offset; }
605 
606  StringRef head;
607  StringAttr locName;
608  unsigned anchorOffset{0};
609  unsigned offset{0};
610 };
611 
612 template <>
613 FailureOr<float> ParserHead::parseLiteral<float>() {
614  FailureOr<StringRef> bytes = consumeNBytes(4);
615  if (failed(bytes))
616  return failure();
617  return llvm::support::endian::read<float>(bytes->bytes_begin(),
618  llvm::endianness::little);
619 }
620 
621 template <>
622 FailureOr<double> ParserHead::parseLiteral<double>() {
623  FailureOr<StringRef> bytes = consumeNBytes(8);
624  if (failed(bytes))
625  return failure();
626  return llvm::support::endian::read<double>(bytes->bytes_begin(),
627  llvm::endianness::little);
628 }
629 
630 template <>
631 FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
632  char const *error = nullptr;
633  uint32_t res{0};
634  unsigned encodingSize{0};
635  StringRef src = curHead();
636  uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
637  src.bytes_end(), &error);
638  if (error)
639  return emitError(getLocation(), error);
640 
641  if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max()))
642  return emitError(getLocation()) << "literal does not fit on 32 bits";
643 
644  res = static_cast<uint32_t>(decoded);
645  offset += encodingSize;
646  return res;
647 }
648 
649 template <>
650 FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
651  char const *error = nullptr;
652  int32_t res{0};
653  unsigned encodingSize{0};
654  StringRef src = curHead();
655  int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
656  src.bytes_end(), &error);
657  if (error)
658  return emitError(getLocation(), error);
659  if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
660  std::isgreater(std::numeric_limits<int32_t>::min(), decoded))
661  return emitError(getLocation()) << "literal does not fit on 32 bits";
662 
663  res = static_cast<int32_t>(decoded);
664  offset += encodingSize;
665  return res;
666 }
667 
668 template <>
669 FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
670  char const *error = nullptr;
671  unsigned encodingSize{0};
672  StringRef src = curHead();
673  int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
674  src.bytes_end(), &error);
675  if (error)
676  return emitError(getLocation(), error);
677 
678  offset += encodingSize;
679  return res;
680 }
681 
682 FailureOr<uint32_t> ParserHead::parseVectorSize() {
683  return parseLiteral<uint32_t>();
684 }
685 
686 inline FailureOr<uint32_t> ParserHead::parseUI32() {
687  return parseLiteral<uint32_t>();
688 }
689 
690 inline FailureOr<int64_t> ParserHead::parseI64() {
691  return parseLiteral<int64_t>();
692 }
693 
694 template <std::byte opCode>
695 inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
696  return emitError(*currentOpLoc, "unknown instruction opcode: ")
697  << static_cast<int>(opCode);
698 }
699 
700 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
701 void ValueStack::dump() const {
702  llvm::dbgs() << "================= Wasm ValueStack =======================\n";
703  llvm::dbgs() << "size: " << size() << "\n";
704  llvm::dbgs() << "<Top>"
705  << "\n";
706  // Stack is pushed to via push_back. Therefore the top of the stack is the
707  // end of the vector. Iterate in reverse so that the first thing we print
708  // is the top of the stack.
709  size_t stackSize = size();
710  for (size_t idx = 0; idx < stackSize; idx++) {
711  size_t actualIdx = stackSize - 1 - idx;
712  llvm::dbgs() << " ";
713  values[actualIdx].dump();
714  }
715  llvm::dbgs() << "<Bottom>"
716  << "\n";
717  llvm::dbgs() << "=========================================================\n";
718 }
719 #endif
720 
721 parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
722  LDBG() << "Popping from ValueStack\n"
723  << " Elements(s) to pop: " << operandTypes.size() << "\n"
724  << " Current stack size: " << values.size();
725  if (operandTypes.size() > values.size())
726  return emitError(*opLoc,
727  "stack doesn't contain enough values. trying to get ")
728  << operandTypes.size() << " operands on a stack containing only "
729  << values.size() << " values.";
730  size_t stackIdxOffset = values.size() - operandTypes.size();
731  SmallVector<Value> res{};
732  res.reserve(operandTypes.size());
733  for (size_t i{0}; i < operandTypes.size(); ++i) {
734  Value operand = values[i + stackIdxOffset];
735  Type stackType = operand.getType();
736  if (stackType != operandTypes[i])
737  return emitError(*opLoc, "invalid operand type on stack. expecting ")
738  << operandTypes[i] << ", value on stack is of type " << stackType
739  << ".";
740  LDBG() << " POP: " << operand;
741  res.push_back(operand);
742  }
743  values.resize(values.size() - operandTypes.size());
744  LDBG() << " Updated stack size: " << values.size();
745  return res;
746 }
747 
748 LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
749  LDBG() << "Pushing to ValueStack\n"
750  << " Elements(s) to push: " << results.size() << "\n"
751  << " Current stack size: " << values.size();
752  for (Value val : results) {
753  if (!isWasmValueType(val.getType()))
754  return emitError(*opLoc, "invalid value type on stack: ")
755  << val.getType();
756  LDBG() << " PUSH: " << val;
757  values.push_back(val);
758  }
759 
760  LDBG() << " Updated stack size: " << values.size();
761  return success();
762 }
763 
764 template <std::byte EndParseByte>
765 parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
766  UniqueByte<EndParseByte> endByte) {
767  auto res = parse(builder, ByteSequence<EndParseByte>{});
768  if (failed(res))
769  return failure();
770  return res->opResults;
771 }
772 
773 template <std::byte... ExpressionParseEnd>
774 FailureOr<ExpressionParser::ParseResultWithInfo>
776  ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
777  SmallVector<Value> res;
778  for (;;) {
779  currentOpLoc = parser.getLocation();
780  FailureOr<std::byte> opCode = parser.consumeByte();
781  if (failed(opCode))
782  return failure();
783  if (isValueOneOf(*opCode, parsingEndFilters))
784  return {{res, *opCode}};
785  parsed_inst_t resParsed;
786  resParsed = dispatchToInstParser(*opCode, builder);
787  if (failed(resParsed))
788  return failure();
789  std::swap(res, *resParsed);
790  if (failed(pushResults(res)))
791  return failure();
792  }
793 }
794 
/// Instruction parser that reads a local index, checks it against the locals
/// accessible to the function, and builds a LocalGetOp on that local.
// NOTE(review): the line carrying the opcode template argument and the rest
// of the signature is not visible in this listing (the embedded numbering
// skips 797) — restore it from the original file.
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
  Location instLoc = *currentOpLoc;
  if (failed(id))
    return failure();
  // The index must designate one of the known locals.
  if (*id >= locals.size())
    return emitError(instLoc, "invalid local index. function has ")
           << locals.size() << " accessible locals, received index " << *id;
  return {{builder.create<LocalGetOp>(instLoc, locals[*id]).getResult()}};
}
807 
/// Instruction parser that reads a global index, checks it against the global
/// symbol table, and builds a GlobalGetOp with the type and symbol recorded
/// for that global.
// NOTE(review): the line carrying the opcode template argument and the rest
// of the signature is not visible in this listing (the embedded numbering
// skips 810) — restore it from the original file.
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
  Location instLoc = *currentOpLoc;
  if (failed(id))
    return failure();
  // The index must designate one of the known globals.
  if (*id >= symbols.globalSymbols.size())
    return emitError(instLoc, "invalid global index. function has ")
           << symbols.globalSymbols.size()
           << " accessible globals, received index " << *id;
  GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id];
  auto globalOp = builder.create<GlobalGetOp>(instLoc, globalVar.globalType,
                                              globalVar.symbol);

  return {{globalOp.getResult()}};
}
825 
826 template <typename OpToCreate>
827 parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
828  FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
829  if (failed(id))
830  return failure();
831  if (*id >= locals.size())
832  return emitError(*currentOpLoc, "invalid local index. function has ")
833  << locals.size() << " accessible locals, received index " << *id;
834  if (valueStack.empty())
835  return emitError(
836  *currentOpLoc,
837  "invalid stack access, trying to access a value on an empty stack.");
838 
839  parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
840  if (failed(poppedOp))
841  return failure();
842  return {
843  builder.create<OpToCreate>(*currentOpLoc, locals[*id], poppedOp->front())
844  ->getResults()};
845 }
846 
/// Instruction parser delegating to parseSetOrTee with LocalSetOp.
// NOTE(review): the line carrying the opcode template argument and the rest
// of the signature is not visible in this listing (the embedded numbering
// skips 849) — restore it from the original file.
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  return parseSetOrTee<LocalSetOp>(builder);
}

/// Instruction parser delegating to parseSetOrTee with LocalTeeOp.
// NOTE(review): same missing signature line as above (the embedded numbering
// skips 855).
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  return parseSetOrTee<LocalTeeOp>(builder);
}
858 
859 template <typename T>
860 inline Type buildLiteralType(OpBuilder &);
861 
862 template <>
863 inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
864  return builder.getI32Type();
865 }
866 
867 template <>
868 inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
869  return builder.getI64Type();
870 }
871 
872 template <>
873 [[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
874  return builder.getI32Type();
875 }
876 
877 template <>
878 [[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
879  return builder.getI64Type();
880 }
881 
882 template <>
883 inline Type buildLiteralType<float>(OpBuilder &builder) {
884  return builder.getF32Type();
885 }
886 
887 template <>
888 inline Type buildLiteralType<double>(OpBuilder &builder) {
889  return builder.getF64Type();
890 }
891 
892 template <typename ValT,
893  typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
894 struct AttrHolder;
895 
896 template <typename ValT>
897 struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
898  using type = IntegerAttr;
899 };
900 
901 template <typename ValT>
902 struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
903  using type = FloatAttr;
904 };
905 
906 template <typename ValT>
907 using attr_holder_t = typename AttrHolder<ValT>::type;
908 
909 template <typename ValT,
910  typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
911 attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
912  return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val);
913 }
914 
915 template <typename valueT>
916 parsed_inst_t ExpressionParser::parseConstInst(
917  OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
918  auto parsedConstant = parser.parseLiteral<valueT>();
919  if (failed(parsedConstant))
920  return failure();
921  auto constOp =
922  ConstOp::create(builder, *currentOpLoc,
923  buildLiteralAttr<valueT>(builder, *parsedConstant));
924  return {{constOp.getResult()}};
925 }
926 
// NOTE(review): for the four specializations below, the line carrying the
// opcode template argument and the rest of the signature is not visible in
// this listing (the embedded numbering skips 929, 935, 941 and 947) — restore
// them from the original file.

/// Instruction parser building a constant from an int32_t literal.
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  return parseConstInst<int32_t>(builder);
}

/// Instruction parser building a constant from an int64_t literal.
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  return parseConstInst<int64_t>(builder);
}

/// Instruction parser building a constant from a float literal.
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  return parseConstInst<float>(builder);
}

/// Instruction parser building a constant from a double literal.
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
  return parseConstInst<double>(builder);
}
950 
951 template <typename opcode, typename valueType, unsigned int numOperands>
952 inline parsed_inst_t ExpressionParser::buildNumericOp(
953  OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> *) {
954  auto ty = buildLiteralType<valueType>(builder);
955  LDBG() << "*** buildNumericOp: numOperands = " << numOperands
956  << ", type = " << ty << " ***";
957  auto tysToPop = SmallVector<Type, numOperands>();
958  tysToPop.resize(numOperands);
959  std::fill(tysToPop.begin(), tysToPop.end(), ty);
960  auto operands = popOperands(tysToPop);
961  if (failed(operands))
962  return failure();
963  auto op = builder.create<opcode>(*currentOpLoc, *operands).getResult();
964  LDBG() << "Built operation: " << op;
965  return {{op}};
966 }
967 
968 // Convenience macro for generating numerical operations.
969 #define BUILD_NUMERIC_OP(OP_NAME, N_ARGS, PREFIX, SUFFIX, TYPE) \
970  template <> \
971  inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
972  WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>(OpBuilder & builder) { \
973  return buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder); \
974  }
975 
976 // Macro to define binops that only support integer types.
977 #define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \
978  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I32, int32_t) \
979  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I64, int64_t)
980 
981 // Macro to define binops that only support floating point types.
982 #define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX) \
983  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F32, float) \
984  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F64, double)
985 
986 // Macro to define binops that support both floating point and integer types.
987 #define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX) \
988  BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \
989  BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)
990 
991 // Macro to implement unary ops that only support integers.
992 #define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX) \
993  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I32, int32_t) \
994  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I64, int64_t)
995 
996 // Macro to implement unary ops that only support floating point types.
997 #define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX) \
998  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F32, float) \
999  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F64, double)
1000 
1001 BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
1002 BUILD_NUMERIC_BINOP_FP(DivOp, div)
1005 BUILD_NUMERIC_BINOP_INT(AndOp, and)
1006 BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
1007 BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
1008 BUILD_NUMERIC_BINOP_INT(OrOp, or)
1009 BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
1010 BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
1011 BUILD_NUMERIC_BINOP_INT(RotlOp, rotl)
1012 BUILD_NUMERIC_BINOP_INT(RotrOp, rotr)
1013 BUILD_NUMERIC_BINOP_INT(ShLOp, shl)
1014 BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
1015 BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
1016 BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
1018 BUILD_NUMERIC_BINOP_INTFP(MulOp, mul)
1019 BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
1020 BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
1023 BUILD_NUMERIC_UNARY_OP_FP(NegOp, neg)
1024 BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
1025 BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
1026 BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz)
1027 BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz)
1028 BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
1029 
1030 // Don't need these anymore so let's undef them.
1031 #undef BUILD_NUMERIC_BINOP_FP
1032 #undef BUILD_NUMERIC_BINOP_INT
1033 #undef BUILD_NUMERIC_BINOP_INTFP
1034 #undef BUILD_NUMERIC_UNARY_OP_FP
1035 #undef BUILD_NUMERIC_UNARY_OP_INT
1036 #undef BUILD_NUMERIC_OP
1037 #undef BUILD_NUMERIC_CAST_OP
1038 
1039 class WasmBinaryParser {
1040 private:
1041  struct SectionRegistry {
1042  using section_location_t = StringRef;
1043 
1044  std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
1045  registry;
1046 
1047  template <WasmSectionType SecType>
1048  std::conditional_t<sectionShouldBeUnique(SecType),
1049  std::optional<section_location_t>,
1051  getContentForSection() const {
1052  constexpr auto idx = static_cast<size_t>(SecType);
1053  if constexpr (sectionShouldBeUnique(SecType)) {
1054  return registry[idx].empty() ? std::nullopt
1055  : std::make_optional(registry[idx][0]);
1056  } else {
1057  return registry[idx];
1058  }
1059  }
1060 
1061  bool hasSection(WasmSectionType secType) const {
1062  return !registry[static_cast<size_t>(secType)].empty();
1063  }
1064 
1065  ///
1066  /// @returns success if registration valid, failure in case registration
1067  /// can't be done (if another section of same type already exist and this
1068  /// section type should only be present once)
1069  ///
1070  LogicalResult registerSection(WasmSectionType secType,
1071  section_location_t location, Location loc) {
1072  if (sectionShouldBeUnique(secType) && hasSection(secType))
1073  return emitError(loc,
1074  "trying to add a second instance of unique section");
1075 
1076  registry[static_cast<size_t>(secType)].push_back(location);
1077  emitRemark(loc, "Adding section with section ID ")
1078  << static_cast<uint8_t>(secType);
1079  return success();
1080  }
1081 
1082  LogicalResult populateFromBody(ParserHead ph) {
1083  while (!ph.end()) {
1084  FileLineColLoc sectionLoc = ph.getLocation();
1085  FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
1086  if (failed(secType))
1087  return failure();
1088 
1089  FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
1090  if (failed(secSizeParsed))
1091  return failure();
1092 
1093  uint32_t secSize = *secSizeParsed;
1094  FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
1095  if (failed(sectionContent))
1096  return failure();
1097 
1098  LogicalResult registration =
1099  registerSection(*secType, *sectionContent, sectionLoc);
1100 
1101  if (failed(registration))
1102  return failure();
1103  }
1104  return success();
1105  }
1106  };
1107 
1108  auto getLocation(int offset = 0) const {
1109  return FileLineColLoc::get(srcName, 0, offset);
1110  }
1111 
1112  template <WasmSectionType>
1113  LogicalResult parseSectionItem(ParserHead &, size_t);
1114 
1115  template <WasmSectionType section>
1116  LogicalResult parseSection() {
1117  auto secName = std::string{wasmSectionName<section>};
1118  auto sectionNameAttr =
1119  StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION");
1120  unsigned offset = 0;
1121  auto getLocation = [sectionNameAttr, &offset]() {
1122  return FileLineColLoc::get(sectionNameAttr, 0, offset);
1123  };
1124  auto secContent = registry.getContentForSection<section>();
1125  if (!secContent) {
1126  LDBG() << secName << " section is not present in file.";
1127  return success();
1128  }
1129 
1130  auto secSrc = secContent.value();
1131  ParserHead ph{secSrc, sectionNameAttr};
1132  FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
1133  if (failed(nElemsParsed))
1134  return failure();
1135  uint32_t nElems = *nElemsParsed;
1136  LDBG() << "starting to parse " << nElems << " items for section "
1137  << secName;
1138  for (size_t i = 0; i < nElems; ++i) {
1139  if (failed(parseSectionItem<section>(ph, i)))
1140  return failure();
1141  }
1142 
1143  if (!ph.end())
1144  return emitError(getLocation(), "unparsed garbage at end of section ")
1145  << secName;
1146  return success();
1147  }
1148 
1149  /// Handles the registration of a function import
1150  LogicalResult visitImport(Location loc, StringRef moduleName,
1151  StringRef importName, TypeIdxRecord tid) {
1152  using llvm::Twine;
1153  if (tid.id >= symbols.moduleFuncTypes.size())
1154  return emitError(loc, "invalid type id: ")
1155  << tid.id << ". Only " << symbols.moduleFuncTypes.size()
1156  << " type registration.";
1157  FunctionType type = symbols.moduleFuncTypes[tid.id];
1158  std::string symbol = symbols.getNewFuncSymbolName();
1159  auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
1160  importName, type);
1161  symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
1162  return funcOp.verify();
1163  }
1164 
1165  /// Handles the registration of a memory import
1166  LogicalResult visitImport(Location loc, StringRef moduleName,
1167  StringRef importName, LimitType limitType) {
1168  std::string symbol = symbols.getNewMemorySymbolName();
1169  auto memOp = MemImportOp::create(builder, loc, symbol, moduleName,
1170  importName, limitType);
1171  symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
1172  return memOp.verify();
1173  }
1174 
1175  /// Handles the registration of a table import
1176  LogicalResult visitImport(Location loc, StringRef moduleName,
1177  StringRef importName, TableType tableType) {
1178  std::string symbol = symbols.getNewTableSymbolName();
1179  auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName,
1180  importName, tableType);
1181  symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
1182  return tableOp.verify();
1183  }
1184 
1185  /// Handles the registration of a global variable import
1186  LogicalResult visitImport(Location loc, StringRef moduleName,
1187  StringRef importName, GlobalTypeRecord globalType) {
1188  std::string symbol = symbols.getNewGlobalSymbolName();
1189  auto giOp =
1190  GlobalImportOp::create(builder, loc, symbol, moduleName, importName,
1191  globalType.type, globalType.isMutable);
1192  symbols.globalSymbols.push_back(
1193  {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
1194  return giOp.verify();
1195  }
1196 
1197  // Detect occurence of errors
1198  LogicalResult peekDiag(Diagnostic &diag) {
1199  if (diag.getSeverity() == DiagnosticSeverity::Error)
1200  isValid = false;
1201  return failure();
1202  }
1203 
1204 public:
1205  WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
1206  : builder{ctx}, ctx{ctx} {
1208  [this](Diagnostic &diag) { return peekDiag(diag); });
1209  ctx->loadAllAvailableDialects();
1210  if (sourceMgr.getNumBuffers() != 1) {
1211  emitError(UnknownLoc::get(ctx), "one source file should be provided");
1212  return;
1213  }
1214  uint32_t sourceBufId = sourceMgr.getMainFileID();
1215  StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
1216  srcName = StringAttr::get(
1217  ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
1218 
1219  auto parser = ParserHead{source, srcName};
1220  auto const wasmHeader = StringRef{"\0asm", 4};
1221  FileLineColLoc magicLoc = parser.getLocation();
1222  FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
1223  if (failed(magic) || magic->compare(wasmHeader)) {
1224  emitError(magicLoc, "source file does not contain valid Wasm header.");
1225  return;
1226  }
1227  auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
1228  FileLineColLoc versionLoc = parser.getLocation();
1229  FailureOr<StringRef> version =
1230  parser.consumeNBytes(expectedVersionString.size());
1231  if (failed(version))
1232  return;
1233  if (version->compare(expectedVersionString)) {
1234  emitError(versionLoc,
1235  "unsupported Wasm version. only version 1 is supported");
1236  return;
1237  }
1238  LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
1239  if (failed(fillRegistry))
1240  return;
1241 
1242  mOp = ModuleOp::create(builder, getLocation());
1243  builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
1244  LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
1245  if (failed(parsingTypes))
1246  return;
1247 
1248  LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
1249  if (failed(parsingImports))
1250  return;
1251 
1252  firstInternalFuncID = symbols.funcSymbols.size();
1253 
1254  LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
1255  if (failed(parsingFunctions))
1256  return;
1257 
1258  LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
1259  if (failed(parsingTables))
1260  return;
1261 
1262  LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
1263  if (failed(parsingMems))
1264  return;
1265 
1266  LogicalResult parsingGlobals = parseSection<WasmSectionType::GLOBAL>();
1267  if (failed(parsingGlobals))
1268  return;
1269 
1270  LogicalResult parsingCode = parseSection<WasmSectionType::CODE>();
1271  if (failed(parsingCode))
1272  return;
1273 
1274  LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
1275  if (failed(parsingExports))
1276  return;
1277 
1278  // Copy over sizes of containers into statistics.
1279  LDBG() << "WASM Imports:"
1280  << "\n"
1281  << " - Num functions: " << symbols.funcSymbols.size() << "\n"
1282  << " - Num globals: " << symbols.globalSymbols.size() << "\n"
1283  << " - Num memories: " << symbols.memSymbols.size() << "\n"
1284  << " - Num tables: " << symbols.tableSymbols.size();
1285  }
1286 
1287  ModuleOp getModule() {
1288  if (isValid)
1289  return mOp;
1290  if (mOp)
1291  mOp.erase();
1292  return ModuleOp{};
1293  }
1294 
1295 private:
1296  mlir::StringAttr srcName;
1297  OpBuilder builder;
1298  WasmModuleSymbolTables symbols;
1299  MLIRContext *ctx;
1300  ModuleOp mOp;
1301  SectionRegistry registry;
1302  size_t firstInternalFuncID{0};
1303  bool isValid{true};
1304 };
1305 
1306 template <>
1307 LogicalResult
1308 WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
1309  size_t) {
1310  FileLineColLoc importLoc = ph.getLocation();
1311  auto moduleName = ph.parseName();
1312  if (failed(moduleName))
1313  return failure();
1314 
1315  auto importName = ph.parseName();
1316  if (failed(importName))
1317  return failure();
1318 
1319  FailureOr<ImportDesc> import = ph.parseImportDesc(ctx);
1320  if (failed(import))
1321  return failure();
1322 
1323  return std::visit(
1324  [this, importLoc, &moduleName, &importName](auto import) {
1325  return visitImport(importLoc, *moduleName, *importName, import);
1326  },
1327  *import);
1328 }
1329 
1330 template <>
1331 LogicalResult
1332 WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
1333  size_t) {
1334  FileLineColLoc exportLoc = ph.getLocation();
1335 
1336  auto exportName = ph.parseName();
1337  if (failed(exportName))
1338  return failure();
1339 
1340  FailureOr<std::byte> opcode = ph.consumeByte();
1341  if (failed(opcode))
1342  return failure();
1343 
1344  FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
1345  if (failed(idx))
1346  return failure();
1347 
1348  using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
1351 
1352  SymbolRefDesc currentSymbolList;
1353  std::string symbolType = "";
1354  switch (*opcode) {
1356  symbolType = "function";
1357  currentSymbolList = symbols.funcSymbols;
1358  break;
1360  symbolType = "table";
1361  currentSymbolList = symbols.tableSymbols;
1362  break;
1364  symbolType = "memory";
1365  currentSymbolList = symbols.memSymbols;
1366  break;
1368  symbolType = "global";
1369  currentSymbolList = symbols.globalSymbols;
1370  break;
1371  default:
1372  return emitError(exportLoc, "invalid value for export type: ")
1373  << std::to_integer<unsigned>(*opcode);
1374  }
1375 
1376  auto currentSymbol = std::visit(
1377  [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> {
1378  if (*idx > list.size()) {
1379  emitError(
1380  exportLoc,
1381  llvm::formatv(
1382  "trying to export {0} {1} which is undefined in this scope",
1383  symbolType, *idx));
1384  return failure();
1385  }
1386  return list[*idx].symbol;
1387  },
1388  currentSymbolList);
1389 
1390  if (failed(currentSymbol))
1391  return failure();
1392 
1393  Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
1395  StringAttr symName = SymbolTable::getSymbolName(op);
1396  return SymbolTable{mOp}.rename(symName, *exportName);
1397 }
1398 
1399 template <>
1400 LogicalResult
1401 WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
1402  size_t) {
1403  FileLineColLoc opLocation = ph.getLocation();
1404  FailureOr<TableType> tableType = ph.parseTableType(ctx);
1405  if (failed(tableType))
1406  return failure();
1407  LDBG() << " Parsed table description: " << *tableType;
1408  StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
1409  auto tableOp =
1410  TableOp::create(builder, opLocation, symbol.strref(), *tableType);
1411  symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
1412  return success();
1413 }
1414 
1415 template <>
1416 LogicalResult
1417 WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
1418  size_t) {
1419  FileLineColLoc opLoc = ph.getLocation();
1420  auto typeIdxParsed = ph.parseLiteral<uint32_t>();
1421  if (failed(typeIdxParsed))
1422  return failure();
1423  uint32_t typeIdx = *typeIdxParsed;
1424  if (typeIdx >= symbols.moduleFuncTypes.size())
1425  return emitError(getLocation(), "invalid type index: ") << typeIdx;
1426  std::string symbol = symbols.getNewFuncSymbolName();
1427  auto funcOp =
1428  FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
1429  Block *block = funcOp.addEntryBlock();
1430  OpBuilder::InsertionGuard guard{builder};
1431  builder.setInsertionPointToEnd(block);
1432  ReturnOp::create(builder, opLoc);
1433  symbols.funcSymbols.push_back(
1434  {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())},
1435  symbols.moduleFuncTypes[typeIdx]});
1436  return funcOp.verify();
1437 }
1438 
1439 template <>
1440 LogicalResult
1441 WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
1442  size_t) {
1443  FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
1444  if (failed(funcType))
1445  return failure();
1446  LDBG() << "Parsed function type " << *funcType;
1447  symbols.moduleFuncTypes.push_back(*funcType);
1448  return success();
1449 }
1450 
1451 template <>
1452 LogicalResult
1453 WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
1454  size_t) {
1455  FileLineColLoc opLocation = ph.getLocation();
1456  FailureOr<LimitType> memory = ph.parseLimit(ctx);
1457  if (failed(memory))
1458  return failure();
1459 
1460  LDBG() << " Registering memory " << *memory;
1461  std::string symbol = symbols.getNewMemorySymbolName();
1462  auto memOp = MemOp::create(builder, opLocation, symbol, *memory);
1463  symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
1464  return success();
1465 }
1466 
1467 template <>
1468 LogicalResult
1469 WasmBinaryParser::parseSectionItem<WasmSectionType::GLOBAL>(ParserHead &ph,
1470  size_t) {
1471  FileLineColLoc globalLocation = ph.getLocation();
1472  auto globalTypeParsed = ph.parseGlobalType(ctx);
1473  if (failed(globalTypeParsed))
1474  return failure();
1475 
1476  GlobalTypeRecord globalType = *globalTypeParsed;
1477  auto symbol = builder.getStringAttr(symbols.getNewGlobalSymbolName());
1478  auto globalOp = builder.create<wasmssa::GlobalOp>(
1479  globalLocation, symbol, globalType.type, globalType.isMutable);
1480  symbols.globalSymbols.push_back(
1481  {{FlatSymbolRefAttr::get(globalOp)}, globalOp.getType()});
1482  OpBuilder::InsertionGuard guard{builder};
1483  Block *block = builder.createBlock(&globalOp.getInitializer());
1484  builder.setInsertionPointToStart(block);
1485  parsed_inst_t expr = ph.parseExpression(builder, symbols);
1486  if (failed(expr))
1487  return failure();
1488  if (block->empty())
1489  return emitError(globalLocation, "global with empty initializer");
1490  if (expr->size() != 1 && (*expr)[0].getType() != globalType.type)
1491  return emitError(
1492  globalLocation,
1493  "initializer result type does not match global declaration type");
1494  builder.create<ReturnOp>(globalLocation, *expr);
1495  return success();
1496 }
1497 
1498 template <>
1499 LogicalResult WasmBinaryParser::parseSectionItem<WasmSectionType::CODE>(
1500  ParserHead &ph, size_t innerFunctionId) {
1501  unsigned long funcId = innerFunctionId + firstInternalFuncID;
1502  FunctionSymbolRefContainer symRef = symbols.funcSymbols[funcId];
1503  auto funcOp =
1504  dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(mOp, symRef.symbol));
1505  assert(funcOp);
1506  if (failed(ph.parseCodeFor(funcOp, symbols)))
1507  return failure();
1508  return success();
1509 }
1510 } // namespace
1511 
1512 namespace mlir::wasm {
1514  MLIRContext *context) {
1515  WasmBinaryParser wBN{source, context};
1516  ModuleOp mOp = wBN.getModule();
1517  if (mOp)
1518  return {mOp};
1519 
1520  return {nullptr};
1521 }
1522 } // namespace mlir::wasm
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define EXPORT
#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX)
#define APPLY_WASM_SEC_TRANSFORM
#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)
#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX)
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
FloatType getF32Type()
Definition: Builders.cpp:42
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
FloatType getF64Type()
Definition: Builders.cpp:44
HandlerID registerHandler(HandlerTy handler)
Register a new handler for diagnostics to the engine.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
An instance of this location represents a tuple of file, line number, and column number.
Definition: Location.h:174
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:157
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
DiagnosticEngine & getDiagEngine()
Returns the diagnostic engine for this context.
void loadAllAvailableDialects()
Load all dialects available in the registry in this context.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:456
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op refered to by the given name to the given new name and updates the sym...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
OwningOpRef< ModuleOp > importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext *context)
If source contains a valid Wasm binary file, this function returns a a ModuleOp containing the repres...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static constexpr std::byte memory
static constexpr std::byte table
static constexpr std::byte global
static constexpr std::byte function
Byte encodings describing the mutability of globals.
static constexpr std::byte memType
static constexpr std::byte typeID
static constexpr std::byte tableType
static constexpr std::byte globalType
Byte encodings for Wasm limits.
static constexpr std::byte globalGet
static constexpr std::byte constI64
static constexpr std::byte constFP64
static constexpr std::byte localTee
static constexpr std::byte localGet
static constexpr std::byte localSet
static constexpr std::byte constI32
static constexpr std::byte constFP32
static constexpr std::byte externRef
static constexpr std::byte i32
static constexpr std::byte funcType
static constexpr std::byte i64
static constexpr std::byte funcRef
static constexpr std::byte v128
static constexpr std::byte f64
static constexpr std::byte f32
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.