MLIR  22.0.0git
TranslateFromWasm.cpp
Go to the documentation of this file.
1 //===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the WebAssembly importer.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/Support/LLVM.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/DebugLog.h"
25 #include "llvm/Support/Endian.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/LEB128.h"
28 #include "llvm/Support/LogicalResult.h"
29 
30 #include <cstddef>
31 #include <cstdint>
32 #include <variant>
33 
34 #define DEBUG_TYPE "wasm-translate"
35 
36 static_assert(CHAR_BIT == 8,
37  "This code expects std::byte to be exactly 8 bits");
38 
39 using namespace mlir;
40 using namespace mlir::wasm;
41 using namespace mlir::wasmssa;
42 
43 namespace {
/// Raw section identifier as stored in the binary (one byte).
using section_id_t = uint8_t;

/// Section identifiers known to the importer, with their binary encoding.
enum struct WasmSectionType : section_id_t {
  CUSTOM = 0,
  TYPE = 1,
  IMPORT = 2,
  FUNCTION = 3,
  TABLE = 4,
  MEMORY = 5,
  GLOBAL = 6,
  EXPORT = 7,
  START = 8,
  ELEMENT = 9,
  CODE = 10,
  DATA = 11,
  DATACOUNT = 12
};

/// Largest section ID accepted by the parser; higher IDs are rejected.
constexpr section_id_t highestWasmSectionID{
    static_cast<section_id_t>(WasmSectionType::DATACOUNT)};

/// X-macro applying WASM_SEC_TRANSFORM to every WasmSectionType enumerator.
#define APPLY_WASM_SEC_TRANSFORM                                              \
  WASM_SEC_TRANSFORM(CUSTOM)                                                   \
  WASM_SEC_TRANSFORM(TYPE)                                                     \
  WASM_SEC_TRANSFORM(IMPORT)                                                   \
  WASM_SEC_TRANSFORM(FUNCTION)                                                 \
  WASM_SEC_TRANSFORM(TABLE)                                                    \
  WASM_SEC_TRANSFORM(MEMORY)                                                   \
  WASM_SEC_TRANSFORM(GLOBAL)                                                   \
  WASM_SEC_TRANSFORM(EXPORT)                                                   \
  WASM_SEC_TRANSFORM(START)                                                    \
  WASM_SEC_TRANSFORM(ELEMENT)                                                  \
  WASM_SEC_TRANSFORM(CODE)                                                     \
  WASM_SEC_TRANSFORM(DATA)                                                     \
  WASM_SEC_TRANSFORM(DATACOUNT)

/// Printable name of a section type. The primary template is only a
/// fallback; every enumerator gets a specialization below.
template <WasmSectionType>
constexpr const char *wasmSectionName = "";

#define WASM_SEC_TRANSFORM(section)                                            \
  template <>                                                                  \
  [[maybe_unused]] constexpr const char                                        \
      *wasmSectionName<WasmSectionType::section> = #section;
// The list must actually be expanded between the #define and the #undef;
// otherwise no specialization is generated and every name stays "".
APPLY_WASM_SEC_TRANSFORM
#undef WASM_SEC_TRANSFORM

/// Every section except custom sections may appear at most once.
constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
  return secType != WasmSectionType::CUSTOM;
}
92 
/// Compile-time list of bytes. It holds no data: the bytes live in the
/// template arguments and are recovered by deducing them from the type.
template <std::byte... SeqBytes>
struct ByteSequence {};

/// A ByteSequence made of exactly one byte.
template <std::byte SingleByte>
struct UniqueByte : public ByteSequence<SingleByte> {};
99 
/// Byte encodings that introduce a plain value type. parseBlockType compares
/// a block's leading byte against this set to decide whether the block
/// signature is a single value type; any other byte is decoded as a type
/// index.
/// NOTE(review): confirm this list covers every value-type encoding accepted
/// by parseValueType; a value-type byte missing here would be misread by
/// parseBlockType as the start of a type index.
[[maybe_unused]] constexpr ByteSequence<
    WasmBinaryEncoding::Type::v128> valueTypesEncodings{};
104 
105 template <std::byte... allowedFlags>
106 constexpr bool isValueOneOf(std::byte value,
107  ByteSequence<allowedFlags...> = {}) {
108  return ((value == allowedFlags) | ... | false);
109 }
110 
111 template <std::byte... flags>
112 constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
113  return !isValueOneOf<flags...>(value);
114 }
115 
/// Parsed global type: the value type and whether the global is mutable.
struct GlobalTypeRecord {
  Type type;
  bool isMutable;
};

/// Index into the module's list of function types.
struct TypeIdxRecord {
  size_t id;
};

/// Reference to a symbol defined by the module.
struct SymbolRefContainer {
  FlatSymbolRefAttr symbol;
};

/// Symbol reference to a global, together with the global's type.
struct GlobalSymbolRefContainer : SymbolRefContainer {
  Type globalType;
};

/// Symbol reference to a function, together with its function type.
struct FunctionSymbolRefContainer : SymbolRefContainer {
  FunctionType functionType;
};

/// Description of an import, as produced by ParserHead::parseImportDesc:
/// a type index, a table type, limits (presumably for a memory import) or a
/// global type.
using ImportDesc =
    std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;

/// Result of parsing an instruction or expression: the produced values, or
/// failure.
using parsed_inst_t = FailureOr<SmallVector<Value>>;

/// Block type with neither parameters nor results.
struct EmptyBlockMarker {};
/// A parsed block type: empty, an index into the module's function types,
/// or a single result value type.
using BlockTypeParseResult =
    std::variant<EmptyBlockMarker, TypeIdxRecord, Type>;
145 
/// Symbol tables built while importing a module, plus helpers that generate
/// fresh, sequentially numbered symbol names.
struct WasmModuleSymbolTables {
  SmallVector<SymbolRefContainer> tableSymbols;
  /// Function types of the module; a TypeIdxRecord indexes into this list.
  SmallVector<FunctionType> moduleFuncTypes;

  /// Returns "<prefix><id>", e.g. ("func_", 3) gives "func_3".
  std::string getNewSymbolName(StringRef prefix, size_t id) const {
    return (prefix + Twine{id}).str();
  }

  /// Name for the next function symbol, numbered by how many function
  /// symbols are already registered.
  std::string getNewFuncSymbolName() const {
    size_t id = funcSymbols.size();
    return getNewSymbolName("func_", id);
  }

  /// Name for the next global symbol, numbered by how many global symbols
  /// are already registered.
  std::string getNewGlobalSymbolName() const {
    size_t id = globalSymbols.size();
    return getNewSymbolName("global_", id);
  }

  /// Name for the next memory symbol, numbered by how many memory symbols
  /// are already registered.
  std::string getNewMemorySymbolName() const {
    size_t id = memSymbols.size();
    return getNewSymbolName("mem_", id);
  }

  /// Name for the next table symbol, numbered by the size of tableSymbols.
  std::string getNewTableSymbolName() const {
    size_t id = tableSymbols.size();
    return getNewSymbolName("table_", id);
  }
};
177 
178 class ParserHead;
179 
180 /// Wrapper around SmallVector to only allow access as push and pop on the
181 /// stack. Makes sure that there are no "free accesses" on the stack to preserve
182 /// its state.
183 /// This class also keep tracks of the Wasm labels defined by different ops,
184 /// which can be targeted by control flow ops. This can be modeled as part of
185 /// the Value Stack as Wasm control flow ops can only target enclosing labels.
186 class ValueStack {
187 private:
188  struct LabelLevel {
189  size_t stackIdx;
190  LabelLevelOpInterface levelOp;
191  };
192 
193 public:
194  bool empty() const { return values.empty(); }
195 
196  size_t size() const { return values.size(); }
197 
198  /// Pops values from the stack because they are being used in an operation.
199  /// @param operandTypes The list of expected types of the operation, used
200  /// to know how many values to pop and check if the types match the
201  /// expectation.
202  /// @param opLoc Location of the caller, used to report accurately the
203  /// location
204  /// if an error occurs.
205  /// @return Failure or the vector of popped values.
206  FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
207  Location *opLoc);
208 
209  /// Push the results of an operation to the stack so they can be used in a
210  /// following operation.
211  /// @param results The list of results of the operation
212  /// @param opLoc Location of the caller, used to report accurately the
213  /// location
214  /// if an error occurs.
215  LogicalResult pushResults(ValueRange results, Location *opLoc);
216 
217  void addLabelLevel(LabelLevelOpInterface levelOp) {
218  labelLevel.push_back({values.size(), levelOp});
219  LDBG() << "Adding a new frame context to ValueStack";
220  }
221 
222  void dropLabelLevel() {
223  assert(!labelLevel.empty() && "Trying to drop a frame from empty context");
224  auto newSize = labelLevel.pop_back_val().stackIdx;
225  values.truncate(newSize);
226  }
227 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
228  /// A simple dump function for debugging.
229  /// Writes output to llvm::dbgs().
230  LLVM_DUMP_METHOD void dump() const;
231 #endif
232 
233 private:
234  SmallVector<Value> values;
235  SmallVector<LabelLevel> labelLevel;
236 };
237 
238 using local_val_t = TypedValue<wasmssa::LocalRefType>;
239 
240 class ExpressionParser {
241 public:
242  using locals_t = SmallVector<local_val_t>;
243  ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
244  ArrayRef<local_val_t> initLocal)
245  : parser{parser}, symbols{symbols}, locals{initLocal} {}
246 
247 private:
248  template <std::byte opCode>
249  inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
250 
251  template <typename valueT>
252  parsed_inst_t
253  parseConstInst(OpBuilder &builder,
254  std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
255 
256  /// Construct an operation with \p numOperands operands and a single result.
257  /// Each operand must have the same type. Suitable for e.g. binops, unary
258  /// ops, etc.
259  ///
260  /// \p opcode - The WASM opcode to build.
261  /// \p valueType - The operand and result type for the built instruction.
262  /// \p numOperands - The number of operands for the built operation.
263  ///
264  /// \returns The parsed instruction result, or failure.
265  template <typename opcode, typename valueType, unsigned int numOperands>
266  inline parsed_inst_t
267  buildNumericOp(OpBuilder &builder,
268  std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
269 
270  /// Construct a conversion operation of type \p opType that takes a value from
271  /// type \p inputType on the stack and will produce a value of type
272  /// \p outputType.
273  ///
274  /// \p opType - The WASM dialect operation to build.
275  /// \p inputType - The operand type for the built instruction.
276  /// \p outputType - The result type for the built instruction.
277  ///
278  /// \returns The parsed instruction result, or failure.
279  template <typename opType, typename inputType, typename outputType,
280  typename... extraArgsT>
281  inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...);
282 
283  /// This function generates a dispatch tree to associate an opcode with a
284  /// parser. Parsers are registered by specialising the
285  /// `parseSpecificInstruction` function for the op code to handle.
286  ///
287  /// The dispatcher is generated by recursively creating all possible patterns
288  /// for an opcode and calling the relevant parser on the leaf.
289  ///
290  /// @tparam patternBitSize is the first bit for which the pattern is not fixed
291  ///
292  /// @tparam highBitPattern is the fixed pattern that this instance handles for
293  /// the 8-patternBitSize bits
294  template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
295  inline parsed_inst_t dispatchToInstParser(std::byte opCode,
296  OpBuilder &builder) {
297  static_assert(patternBitSize <= 8,
298  "PatternBitSize is outside of range of opcode space! "
299  "(expected at most 8 bits)");
300  if constexpr (patternBitSize < 8) {
301  constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
302  constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
303  constexpr size_t nextPatternBitSize = patternBitSize + 1;
304  if ((opCode & bitSelect) != std::byte{0})
305  return dispatchToInstParser<nextPatternBitSize,
306  nextHighBitPatternStem | std::byte{1}>(
307  opCode, builder);
308  return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
309  opCode, builder);
310  } else {
311  return parseSpecificInstruction<highBitPattern>(builder);
312  }
313  }
314 
315  ///
316  /// RAII guard class for creating a nesting level
317  ///
318  struct NestingContextGuard {
319  NestingContextGuard(ExpressionParser &parser, LabelLevelOpInterface levelOp)
320  : parser{parser} {
321  parser.addNestingContextLevel(levelOp);
322  }
323  NestingContextGuard(NestingContextGuard &&other) : parser{other.parser} {
324  other.shouldDropOnDestruct = false;
325  }
326  NestingContextGuard(NestingContextGuard const &) = delete;
327  ~NestingContextGuard() {
328  if (shouldDropOnDestruct)
329  parser.dropNestingContextLevel();
330  }
331  ExpressionParser &parser;
332  bool shouldDropOnDestruct = true;
333  };
334 
335  void addNestingContextLevel(LabelLevelOpInterface levelOp) {
336  valueStack.addLabelLevel(levelOp);
337  }
338 
339  void dropNestingContextLevel() {
340  // Should always succeed as we are droping the frame that was previously
341  // created.
342  valueStack.dropLabelLevel();
343  }
344 
345  llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
346  EmptyBlockMarker) {
347  return builder.getFunctionType({}, {});
348  }
349 
350  llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
351  TypeIdxRecord type) {
352  if (type.id >= symbols.moduleFuncTypes.size())
353  return emitError(*currentOpLoc,
354  "type index references nonexistent type (")
355  << type.id << "). Only " << symbols.moduleFuncTypes.size()
356  << " types are registered";
357  return symbols.moduleFuncTypes[type.id];
358  }
359 
360  llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
361  Type valType) {
362  return builder.getFunctionType({}, {valType});
363  }
364 
365  llvm::FailureOr<FunctionType>
366  getFuncTypeFor(OpBuilder &builder, BlockTypeParseResult parseResult) {
367  return std::visit(
368  [this, &builder](auto value) { return getFuncTypeFor(builder, value); },
369  parseResult);
370  }
371 
372  llvm::FailureOr<FunctionType>
373  getFuncTypeFor(OpBuilder &builder,
374  llvm::FailureOr<BlockTypeParseResult> parseResult) {
375  if (llvm::failed(parseResult))
376  return failure();
377  return getFuncTypeFor(builder, *parseResult);
378  }
379 
380  llvm::FailureOr<FunctionType> parseBlockFuncType(OpBuilder &builder);
381 
382  struct ParseResultWithInfo {
383  SmallVector<Value> opResults;
384  std::byte endingByte;
385  };
386 
387  template <typename FilterT = ByteSequence<WasmBinaryEncoding::endByte>>
388  /// @param blockToFill: the block which content will be populated
389  /// @param resType: the type that this block is supposed to return
390  llvm::FailureOr<std::byte>
391  parseBlockContent(OpBuilder &builder, Block *blockToFill, TypeRange resTypes,
392  Location opLoc, LabelLevelOpInterface levelOp,
393  FilterT parseEndBytes = {}) {
394  OpBuilder::InsertionGuard guard{builder};
395  builder.setInsertionPointToStart(blockToFill);
396  LDBG() << "parsing a block of type "
397  << builder.getFunctionType(blockToFill->getArgumentTypes(),
398  resTypes);
399  auto nC = addNesting(levelOp);
400 
401  if (failed(pushResults(blockToFill->getArguments())))
402  return failure();
403  auto bodyParsingRes = parse(builder, parseEndBytes);
404  if (failed(bodyParsingRes))
405  return failure();
406  auto returnOperands = popOperands(resTypes);
407  if (failed(returnOperands))
408  return failure();
409  BlockReturnOp::create(builder, opLoc, *returnOperands);
410  LDBG() << "end of parsing of a block";
411  return bodyParsingRes->endingByte;
412  }
413 
414 public:
415  template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
416  parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
417 
418  template <std::byte... ExpressionParseEnd>
419  FailureOr<ParseResultWithInfo>
420  parse(OpBuilder &builder,
421  ByteSequence<ExpressionParseEnd...> parsingEndFilters);
422 
423  NestingContextGuard addNesting(LabelLevelOpInterface levelOp) {
424  return NestingContextGuard{*this, levelOp};
425  }
426 
427  FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes) {
428  return valueStack.popOperands(operandTypes, &currentOpLoc.value());
429  }
430 
431  LogicalResult pushResults(ValueRange results) {
432  return valueStack.pushResults(results, &currentOpLoc.value());
433  }
434 
435  /// The local.set and local.tee operations behave similarly and only differ
436  /// on their return value. This function factorizes the behavior of the two
437  /// operations in one place.
438  template <typename OpToCreate>
439  parsed_inst_t parseSetOrTee(OpBuilder &);
440 
441  /// Blocks and Loops have a similar format and differ only in how their exit
442  /// is handled which doesn´t matter at parsing time. Factorizes in one
443  /// function.
444  template <typename OpToCreate>
445  parsed_inst_t parseBlockLikeOp(OpBuilder &);
446 
447 private:
448  std::optional<Location> currentOpLoc;
449  ParserHead &parser;
450  WasmModuleSymbolTables const &symbols;
451  locals_t locals;
452  ValueStack valueStack;
453 };
454 
455 class ParserHead {
456 public:
457  ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
458  ParserHead(ParserHead &&) = default;
459 
460 private:
461  ParserHead(ParserHead const &other) = default;
462 
463 public:
464  auto getLocation() const {
465  return FileLineColLoc::get(locName, 0, anchorOffset + offset);
466  }
467 
468  FailureOr<StringRef> consumeNBytes(size_t nBytes) {
469  LDBG() << "Consume " << nBytes << " bytes";
470  LDBG() << " Bytes remaining: " << size();
471  LDBG() << " Current offset: " << offset;
472  if (nBytes > size())
473  return emitError(getLocation(), "trying to extract ")
474  << nBytes << "bytes when only " << size() << "are available";
475 
476  StringRef res = head.slice(offset, offset + nBytes);
477  offset += nBytes;
478  LDBG() << " Updated offset (+" << nBytes << "): " << offset;
479  return res;
480  }
481 
482  FailureOr<std::byte> consumeByte() {
483  FailureOr<StringRef> res = consumeNBytes(1);
484  if (failed(res))
485  return failure();
486  return std::byte{*res->bytes_begin()};
487  }
488 
489  template <typename T>
490  FailureOr<T> parseLiteral();
491 
492  FailureOr<uint32_t> parseVectorSize();
493 
494 private:
495  // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed
496  // if parseLiteral specialization were moved here, but default GCC on Ubuntu
497  // 22.04 has bug with template specialization in class declaration
498  inline FailureOr<uint32_t> parseUI32();
499  inline FailureOr<int64_t> parseI64();
500 
501 public:
502  FailureOr<StringRef> parseName() {
503  FailureOr<uint32_t> size = parseVectorSize();
504  if (failed(size))
505  return failure();
506 
507  return consumeNBytes(*size);
508  }
509 
510  FailureOr<WasmSectionType> parseWasmSectionType() {
511  FailureOr<std::byte> id = consumeByte();
512  if (failed(id))
513  return failure();
514  if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
515  return emitError(getLocation(), "invalid section ID: ")
516  << static_cast<int>(*id);
517  return static_cast<WasmSectionType>(*id);
518  }
519 
520  FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
521  using WasmLimits = WasmBinaryEncoding::LimitHeader;
522  FileLineColLoc limitLocation = getLocation();
523  FailureOr<std::byte> limitHeader = consumeByte();
524  if (failed(limitHeader))
525  return failure();
526 
527  if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
528  return emitError(limitLocation, "invalid limit header: ")
529  << static_cast<int>(*limitHeader);
530  FailureOr<uint32_t> minParse = parseUI32();
531  if (failed(minParse))
532  return failure();
533  std::optional<uint32_t> max{std::nullopt};
534  if (*limitHeader == WasmLimits::bothLimits) {
535  FailureOr<uint32_t> maxParse = parseUI32();
536  if (failed(maxParse))
537  return failure();
538  max = *maxParse;
539  }
540  return LimitType::get(ctx, *minParse, max);
541  }
542 
/// Reads a one-byte value type encoding and returns the corresponding MLIR
/// type: 32/64-bit integers, f32/f64 floats, a 128-bit integer (presumably
/// for v128), or the WasmSSA funcref/externref types. Unknown encodings are
/// diagnosed at the position of the type byte.
FailureOr<Type> parseValueType(MLIRContext *ctx) {
  FileLineColLoc typeLoc = getLocation();
  FailureOr<std::byte> typeEncoding = consumeByte();
  if (failed(typeEncoding))
    return failure();
  switch (*typeEncoding) {
    return IntegerType::get(ctx, 32);
    return IntegerType::get(ctx, 64);
    return Float32Type::get(ctx);
    return Float64Type::get(ctx);
    return IntegerType::get(ctx, 128);
    return wasmssa::FuncRefType::get(ctx);
    return wasmssa::ExternRefType::get(ctx);
  default:
    return emitError(typeLoc, "invalid value type encoding: ")
           << static_cast<int>(*typeEncoding);
  }
}
568 
569  FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
570  using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
571  FailureOr<Type> typeParsed = parseValueType(ctx);
572  if (failed(typeParsed))
573  return failure();
574  FileLineColLoc mutLoc = getLocation();
575  FailureOr<std::byte> mutSpec = consumeByte();
576  if (failed(mutSpec))
577  return failure();
578  if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
579  return emitError(mutLoc, "invalid global mutability specifier: ")
580  << static_cast<int>(*mutSpec);
581  return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
582  }
583 
584  FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
585  FailureOr<uint32_t> nParamsParsed = parseVectorSize();
586  if (failed(nParamsParsed))
587  return failure();
588  uint32_t nParams = *nParamsParsed;
589  SmallVector<Type> res{};
590  res.reserve(nParams);
591  for (size_t i = 0; i < nParams; ++i) {
592  FailureOr<Type> parsedType = parseValueType(ctx);
593  if (failed(parsedType))
594  return failure();
595  res.push_back(*parsedType);
596  }
597  return TupleType::get(ctx, res);
598  }
599 
600  FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
601  FileLineColLoc typeLoc = getLocation();
602  FailureOr<std::byte> funcTypeHeader = consumeByte();
603  if (failed(funcTypeHeader))
604  return failure();
605  if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
606  return emitError(typeLoc, "invalid function type header byte. Expecting ")
607  << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
608  << " got " << std::to_integer<unsigned>(*funcTypeHeader);
609  FailureOr<TupleType> inputTypes = parseResultType(ctx);
610  if (failed(inputTypes))
611  return failure();
612 
613  FailureOr<TupleType> resTypes = parseResultType(ctx);
614  if (failed(resTypes))
615  return failure();
616 
617  return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
618  }
619 
620  FailureOr<TypeIdxRecord> parseTypeIndex() {
621  FailureOr<uint32_t> res = parseUI32();
622  if (failed(res))
623  return failure();
624  return TypeIdxRecord{*res};
625  }
626 
627  FailureOr<TableType> parseTableType(MLIRContext *ctx) {
628  FailureOr<Type> elmTypeParse = parseValueType(ctx);
629  if (failed(elmTypeParse))
630  return failure();
631  if (!isWasmRefType(*elmTypeParse))
632  return emitError(getLocation(), "invalid element type for table");
633  FailureOr<LimitType> limitParse = parseLimit(ctx);
634  if (failed(limitParse))
635  return failure();
636  return TableType::get(ctx, *elmTypeParse, *limitParse);
637  }
638 
/// Parses an import descriptor: a kind byte followed by the matching
/// description (type index, table type, limits or global type), wrapped in
/// an ImportDesc variant. Unknown kind bytes are diagnosed at the position
/// of the kind byte.
FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
  FileLineColLoc importLoc = getLocation();
  FailureOr<std::byte> importType = consumeByte();
  // Converts a successful FailureOr<T> into a FailureOr<ImportDesc>,
  // propagating failure unchanged.
  auto packager = [](auto parseResult) -> FailureOr<ImportDesc> {
    if (failed(parseResult))
      return failure();
    return {*parseResult};
  };
  if (failed(importType))
    return failure();
  switch (*importType) {
    return packager(parseTypeIndex());
    return packager(parseTableType(ctx));
    return packager(parseLimit(ctx));
    return packager(parseGlobalType(ctx));
  default:
    return emitError(importLoc, "invalid import type descriptor: ")
           << static_cast<int>(*importType);
  }
}
663 
664  parsed_inst_t parseExpression(OpBuilder &builder,
665  WasmModuleSymbolTables const &symbols,
666  ArrayRef<local_val_t> locals = {}) {
667  auto eParser = ExpressionParser{*this, symbols, locals};
668  return eParser.parse(builder);
669  }
670 
/// Parses one code-section entry into the body of \p func.
///
/// Reads the entry size, then works only on that many bytes: the local
/// declarations (each a count and a value type, materialized as LocalOp ops)
/// followed by the body expression. The placeholder ReturnOp already present
/// in \p func is replaced by a ReturnOp of the expression's results. Fails if
/// bytes remain in the entry after the expression.
LogicalResult parseCodeFor(FuncOp func,
                           WasmModuleSymbolTables const &symbols) {
  SmallVector<local_val_t> locals{};
  // The function's single block; its arguments become the first locals.
  Block &block = func.getBody().front();
  // The only op so far is a placeholder return kept for IR validity; it is
  // erased once the real body has been built.
  assert(func.getBody().getBlocks().size() == 1 &&
         "Function should only have its default created block at this point");
  assert(block.getOperations().size() == 1 &&
         "Only the placeholder return op should be present at this point");
  auto returnOp = cast<ReturnOp>(&block.back());
  assert(returnOp);

  FailureOr<uint32_t> codeSizeInBytes = parseUI32();
  if (failed(codeSizeInBytes))
    return failure();
  FailureOr<StringRef> codeContent = consumeNBytes(*codeSizeInBytes);
  if (failed(codeContent))
    return failure();
  // Sub-parser restricted to this entry; its locations are named
  // "<input name>::<function symbol>".
  auto name = StringAttr::get(func->getContext(),
                              locName.str() + "::" + func.getSymName());
  auto cParser = ParserHead{*codeContent, name};
  FailureOr<uint32_t> localVecSize = cParser.parseVectorSize();
  if (failed(localVecSize))
    return failure();
  // Builder positioned at the placeholder return (OpBuilder(Operation *)
  // inserts before the given op).
  OpBuilder builder{&func.getBody().front().back()};
  for (auto arg : block.getArguments())
    locals.push_back(cast<TypedValue<LocalRefType>>(arg));
  // Declare the local ops
  uint32_t nVarVec = *localVecSize;
  for (size_t i = 0; i < nVarVec; ++i) {
    FileLineColLoc varLoc = cParser.getLocation();
    FailureOr<uint32_t> nSubVar = cParser.parseUI32();
    if (failed(nSubVar))
      return failure();
    FailureOr<Type> varT = cParser.parseValueType(func->getContext());
    if (failed(varT))
      return failure();
    // NOTE(review): nSubVar comes straight from the input and is not bounded
    // here; a huge count creates that many LocalOps.
    for (size_t j = 0; j < *nSubVar; ++j) {
      auto local = LocalOp::create(builder, varLoc, *varT);
      locals.push_back(local.getResult());
    }
  }
  parsed_inst_t res = cParser.parseExpression(builder, symbols, locals);
  if (failed(res))
    return failure();
  if (!cParser.end())
    return emitError(cParser.getLocation(),
                     "unparsed garbage remaining at end of code block");
  ReturnOp::create(builder, func->getLoc(), *res);
  returnOp->erase();
  return success();
}
724 
725  llvm::FailureOr<BlockTypeParseResult> parseBlockType(MLIRContext *ctx) {
726  auto loc = getLocation();
727  auto blockIndicator = peek();
728  if (failed(blockIndicator))
729  return failure();
730  if (*blockIndicator == WasmBinaryEncoding::Type::emptyBlockType) {
731  offset += 1;
732  return {EmptyBlockMarker{}};
733  }
734  if (isValueOneOf(*blockIndicator, valueTypesEncodings))
735  return parseValueType(ctx);
736  /// Block type idx is a 32 bit positive integer encoded as a 33 bit signed
737  /// value
738  auto typeIdx = parseI64();
739  if (failed(typeIdx))
740  return failure();
741  if (*typeIdx < 0 || *typeIdx > std::numeric_limits<uint32_t>::max())
742  return emitError(loc, "type ID should be representable with an unsigned "
743  "32 bits integer. Got ")
744  << *typeIdx;
745  return {TypeIdxRecord{static_cast<uint32_t>(*typeIdx)}};
746  }
747 
748  bool end() const { return curHead().empty(); }
749 
750  ParserHead copy() const { return *this; }
751 
752 private:
753  StringRef curHead() const { return head.drop_front(offset); }
754 
755  FailureOr<std::byte> peek() const {
756  if (end())
757  return emitError(
758  getLocation(),
759  "trying to peek at next byte, but input stream is empty");
760  return static_cast<std::byte>(curHead().front());
761  }
762 
763  size_t size() const { return head.size() - offset; }
764 
765  StringRef head;
766  StringAttr locName;
767  unsigned anchorOffset{0};
768  unsigned offset{0};
769 };
770 
771 template <>
772 FailureOr<float> ParserHead::parseLiteral<float>() {
773  FailureOr<StringRef> bytes = consumeNBytes(4);
774  if (failed(bytes))
775  return failure();
776  return llvm::support::endian::read<float>(bytes->bytes_begin(),
777  llvm::endianness::little);
778 }
779 
780 template <>
781 FailureOr<double> ParserHead::parseLiteral<double>() {
782  FailureOr<StringRef> bytes = consumeNBytes(8);
783  if (failed(bytes))
784  return failure();
785  return llvm::support::endian::read<double>(bytes->bytes_begin(),
786  llvm::endianness::little);
787 }
788 
789 template <>
790 FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
791  char const *error = nullptr;
792  uint32_t res{0};
793  unsigned encodingSize{0};
794  StringRef src = curHead();
795  uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
796  src.bytes_end(), &error);
797  if (error)
798  return emitError(getLocation(), error);
799 
800  if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max()))
801  return emitError(getLocation()) << "literal does not fit on 32 bits";
802 
803  res = static_cast<uint32_t>(decoded);
804  offset += encodingSize;
805  return res;
806 }
807 
808 template <>
809 FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
810  char const *error = nullptr;
811  int32_t res{0};
812  unsigned encodingSize{0};
813  StringRef src = curHead();
814  int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
815  src.bytes_end(), &error);
816  if (error)
817  return emitError(getLocation(), error);
818  if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
819  std::isgreater(std::numeric_limits<int32_t>::min(), decoded))
820  return emitError(getLocation()) << "literal does not fit on 32 bits";
821 
822  res = static_cast<int32_t>(decoded);
823  offset += encodingSize;
824  return res;
825 }
826 
827 template <>
828 FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
829  char const *error = nullptr;
830  unsigned encodingSize{0};
831  StringRef src = curHead();
832  int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
833  src.bytes_end(), &error);
834  if (error)
835  return emitError(getLocation(), error);
836 
837  offset += encodingSize;
838  return res;
839 }
840 
841 FailureOr<uint32_t> ParserHead::parseVectorSize() {
842  return parseLiteral<uint32_t>();
843 }
844 
845 inline FailureOr<uint32_t> ParserHead::parseUI32() {
846  return parseLiteral<uint32_t>();
847 }
848 
849 inline FailureOr<int64_t> ParserHead::parseI64() {
850  return parseLiteral<int64_t>();
851 }
852 
853 template <std::byte opCode>
854 inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
855  return emitError(*currentOpLoc, "unknown instruction opcode: ")
856  << static_cast<int>(opCode);
857 }
858 
859 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
860 void ValueStack::dump() const {
861  llvm::dbgs() << "================= Wasm ValueStack =======================\n";
862  llvm::dbgs() << "size: " << size() << "\n";
863  llvm::dbgs() << "nbFrames: " << labelLevel.size() << '\n';
864  llvm::dbgs() << "<Top>"
865  << "\n";
866  // Stack is pushed to via push_back. Therefore the top of the stack is the
867  // end of the vector. Iterate in reverse so that the first thing we print
868  // is the top of the stack.
869  auto indexGetter = [this]() {
870  size_t idx = labelLevel.size();
871  return [this, idx]() mutable -> std::optional<std::pair<size_t, size_t>> {
872  llvm::dbgs() << "IDX: " << idx << '\n';
873  if (idx == 0)
874  return std::nullopt;
875  auto frameId = idx - 1;
876  auto frameLimit = labelLevel[frameId].stackIdx;
877  idx -= 1;
878  return {{frameId, frameLimit}};
879  };
880  };
881  auto getNextFrameIndex = indexGetter();
882  auto nextFrameIdx = getNextFrameIndex();
883  size_t stackSize = size();
884  for (size_t idx = 0; idx < stackSize; ++idx) {
885  size_t actualIdx = stackSize - 1 - idx;
886  while (nextFrameIdx && (nextFrameIdx->second > actualIdx)) {
887  llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first
888  << ")\n";
889  nextFrameIdx = getNextFrameIndex();
890  }
891  llvm::dbgs() << " ";
892  values[actualIdx].dump();
893  }
894  while (nextFrameIdx) {
895  llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first << ")\n";
896  nextFrameIdx = getNextFrameIndex();
897  }
898  llvm::dbgs() << "<Bottom>"
899  << "\n";
900  llvm::dbgs() << "=========================================================\n";
901 }
902 #endif
903 
904 parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
905  LDBG() << "Popping from ValueStack\n"
906  << " Elements(s) to pop: " << operandTypes.size() << "\n"
907  << " Current stack size: " << values.size();
908  if (operandTypes.size() > values.size())
909  return emitError(*opLoc,
910  "stack doesn't contain enough values. trying to get ")
911  << operandTypes.size() << " operands on a stack containing only "
912  << values.size() << " values";
913  size_t stackIdxOffset = values.size() - operandTypes.size();
914  SmallVector<Value> res{};
915  res.reserve(operandTypes.size());
916  for (size_t i{0}; i < operandTypes.size(); ++i) {
917  Value operand = values[i + stackIdxOffset];
918  Type stackType = operand.getType();
919  if (stackType != operandTypes[i])
920  return emitError(*opLoc, "invalid operand type on stack. expecting ")
921  << operandTypes[i] << ", value on stack is of type " << stackType;
922  LDBG() << " POP: " << operand;
923  res.push_back(operand);
924  }
925  values.resize(values.size() - operandTypes.size());
926  LDBG() << " Updated stack size: " << values.size();
927  return res;
928 }
929 
930 LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
931  LDBG() << "Pushing to ValueStack\n"
932  << " Elements(s) to push: " << results.size() << "\n"
933  << " Current stack size: " << values.size();
934  for (Value val : results) {
935  if (!isWasmValueType(val.getType()))
936  return emitError(*opLoc, "invalid value type on stack: ")
937  << val.getType();
938  LDBG() << " PUSH: " << val;
939  values.push_back(val);
940  }
941 
942  LDBG() << " Updated stack size: " << values.size();
943  return success();
944 }
945 
946 template <std::byte EndParseByte>
947 parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
948  UniqueByte<EndParseByte> endByte) {
949  auto res = parse(builder, ByteSequence<EndParseByte>{});
950  if (failed(res))
951  return failure();
952  return res->opResults;
953 }
954 
955 template <std::byte... ExpressionParseEnd>
956 FailureOr<ExpressionParser::ParseResultWithInfo>
958  ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
959  SmallVector<Value> res;
960  for (;;) {
961  currentOpLoc = parser.getLocation();
962  FailureOr<std::byte> opCode = parser.consumeByte();
963  if (failed(opCode))
964  return failure();
965  if (isValueOneOf(*opCode, parsingEndFilters))
966  return {{res, *opCode}};
967  parsed_inst_t resParsed;
968  resParsed = dispatchToInstParser(*opCode, builder);
969  if (failed(resParsed))
970  return failure();
971  std::swap(res, *resParsed);
972  if (failed(pushResults(res)))
973  return failure();
974  }
975 }
976 
977 llvm::FailureOr<FunctionType>
978 ExpressionParser::parseBlockFuncType(OpBuilder &builder) {
979  return getFuncTypeFor(builder, parser.parseBlockType(builder.getContext()));
980 }
981 
982 template <typename OpToCreate>
983 parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
984  auto opLoc = currentOpLoc;
985  auto funcType = parseBlockFuncType(builder);
986  if (failed(funcType))
987  return failure();
988 
989  auto inputTypes = funcType->getInputs();
990  auto inputOps = popOperands(inputTypes);
991  if (failed(inputOps))
992  return failure();
993 
994  Block *curBlock = builder.getBlock();
995  Region *curRegion = curBlock->getParent();
996  auto resTypes = funcType->getResults();
997  llvm::SmallVector<Location> locations{};
998  locations.resize(resTypes.size(), *currentOpLoc);
999  auto *successor =
1000  builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
1001  builder.setInsertionPointToEnd(curBlock);
1002  auto blockOp =
1003  OpToCreate::create(builder, *currentOpLoc, *inputOps, successor);
1004  auto *blockBody = blockOp.createBlock();
1005  if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp)))
1006  return failure();
1007  builder.setInsertionPointToStart(successor);
1008  return {ValueRange{successor->getArguments()}};
1009 }
1010 
1011 template <>
1012 inline parsed_inst_t
1013 ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>(
1014  OpBuilder &builder) {
1015  return parseBlockLikeOp<BlockOp>(builder);
1016 }
1017 
1018 template <>
1019 inline parsed_inst_t
1020 ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>(
1021  OpBuilder &builder) {
1022  return parseBlockLikeOp<LoopOp>(builder);
1023 }
1024 
1025 template <>
1026 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1028  auto opLoc = currentOpLoc;
1029  auto funcType = parseBlockFuncType(builder);
1030  if (failed(funcType))
1031  return failure();
1032 
1033  LDBG() << "Parsing an if instruction of type " << *funcType;
1034  auto inputTypes = funcType->getInputs();
1035  auto conditionValue = popOperands(builder.getI32Type());
1036  if (failed(conditionValue))
1037  return failure();
1038  auto inputOps = popOperands(inputTypes);
1039  if (failed(inputOps))
1040  return failure();
1041 
1042  Block *curBlock = builder.getBlock();
1043  Region *curRegion = curBlock->getParent();
1044  auto resTypes = funcType->getResults();
1045  llvm::SmallVector<Location> locations{};
1046  locations.resize(resTypes.size(), *currentOpLoc);
1047  auto *successor =
1048  builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
1049  builder.setInsertionPointToEnd(curBlock);
1050  auto ifOp = IfOp::create(builder, *currentOpLoc, conditionValue->front(),
1051  *inputOps, successor);
1052  auto *ifEntryBlock = ifOp.createIfBlock();
1053  constexpr auto ifElseFilter =
1054  ByteSequence<WasmBinaryEncoding::endByte,
1056  auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc,
1057  ifOp, ifElseFilter);
1058  if (failed(parseIfRes))
1059  return failure();
1060  if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) {
1061  LDBG() << " else block is present.";
1062  Block *elseEntryBlock = ifOp.createElseBlock();
1063  auto parseElseRes =
1064  parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp);
1065  if (failed(parseElseRes))
1066  return failure();
1067  }
1068  builder.setInsertionPointToStart(successor);
1069  return {ValueRange{successor->getArguments()}};
1070 }
1071 
1072 template <>
1073 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1075  auto level = parser.parseLiteral<uint32_t>();
1076  if (failed(level))
1077  return failure();
1078  Block *curBlock = builder.getBlock();
1079  Region *curRegion = curBlock->getParent();
1080  auto sip = builder.saveInsertionPoint();
1081  Block *elseBlock = builder.createBlock(curRegion, curRegion->end());
1082  auto condition = popOperands(builder.getI32Type());
1083  if (failed(condition))
1084  return failure();
1085  builder.restoreInsertionPoint(sip);
1086  auto targetOp =
1087  LabelBranchingOpInterface::getTargetOpFromBlock(curBlock, *level);
1088  if (failed(targetOp))
1089  return failure();
1090  auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes();
1091  auto branchArgs = popOperands(inputTypes);
1092  if (failed(branchArgs))
1093  return failure();
1094  BranchIfOp::create(builder, *currentOpLoc, condition->front(),
1095  builder.getUI32IntegerAttr(*level), *branchArgs,
1096  elseBlock);
1097  builder.setInsertionPointToStart(elseBlock);
1098  return {*branchArgs};
1099 }
1100 
1101 template <>
1102 inline parsed_inst_t
1103 ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
1104  OpBuilder &builder) {
1105  auto loc = *currentOpLoc;
1106  auto funcIdx = parser.parseLiteral<uint32_t>();
1107  if (failed(funcIdx))
1108  return failure();
1109  if (*funcIdx >= symbols.funcSymbols.size())
1110  return emitError(loc, "Invalid function index: ") << *funcIdx;
1111  auto callee = symbols.funcSymbols[*funcIdx];
1112  llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs();
1113  llvm::ArrayRef<Type> resTypes = callee.functionType.getResults();
1114  parsed_inst_t inOperands = popOperands(inTypes);
1115  if (failed(inOperands))
1116  return failure();
1117  auto callOp =
1118  FuncCallOp::create(builder, loc, resTypes, callee.symbol, *inOperands);
1119  return {callOp.getResults()};
1120 }
1121 
1122 template <>
1123 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1125  FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
1126  Location instLoc = *currentOpLoc;
1127  if (failed(id))
1128  return failure();
1129  if (*id >= locals.size())
1130  return emitError(instLoc, "invalid local index. function has ")
1131  << locals.size() << " accessible locals, received index " << *id;
1132  return {{LocalGetOp::create(builder, instLoc, locals[*id]).getResult()}};
1133 }
1134 
1135 template <>
1136 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1138  FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
1139  Location instLoc = *currentOpLoc;
1140  if (failed(id))
1141  return failure();
1142  if (*id >= symbols.globalSymbols.size())
1143  return emitError(instLoc, "invalid global index. function has ")
1144  << symbols.globalSymbols.size()
1145  << " accessible globals, received index " << *id;
1146  GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id];
1147  auto globalOp = GlobalGetOp::create(builder, instLoc, globalVar.globalType,
1148  globalVar.symbol);
1149 
1150  return {{globalOp.getResult()}};
1151 }
1152 
1153 template <typename OpToCreate>
1154 parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
1155  FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
1156  if (failed(id))
1157  return failure();
1158  if (*id >= locals.size())
1159  return emitError(*currentOpLoc, "invalid local index. function has ")
1160  << locals.size() << " accessible locals, received index " << *id;
1161  if (valueStack.empty())
1162  return emitError(
1163  *currentOpLoc,
1164  "invalid stack access, trying to access a value on an empty stack");
1165 
1166  parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
1167  if (failed(poppedOp))
1168  return failure();
1169  return {
1170  OpToCreate::create(builder, *currentOpLoc, locals[*id], poppedOp->front())
1171  ->getResults()};
1172 }
1173 
1174 template <>
1175 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1177  return parseSetOrTee<LocalSetOp>(builder);
1178 }
1179 
1180 template <>
1181 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1183  return parseSetOrTee<LocalTeeOp>(builder);
1184 }
1185 
1186 template <typename T>
1187 inline Type buildLiteralType(OpBuilder &);
1188 
1189 template <>
1190 inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
1191  return builder.getI32Type();
1192 }
1193 
1194 template <>
1195 inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
1196  return builder.getI64Type();
1197 }
1198 
1199 template <>
1200 [[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
1201  return builder.getI32Type();
1202 }
1203 
1204 template <>
1205 [[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
1206  return builder.getI64Type();
1207 }
1208 
1209 template <>
1210 inline Type buildLiteralType<float>(OpBuilder &builder) {
1211  return builder.getF32Type();
1212 }
1213 
1214 template <>
1215 inline Type buildLiteralType<double>(OpBuilder &builder) {
1216  return builder.getF64Type();
1217 }
1218 
1219 template <typename ValT,
1220  typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
1221 struct AttrHolder;
1222 
1223 template <typename ValT>
1224 struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
1225  using type = IntegerAttr;
1226 };
1227 
1228 template <typename ValT>
1229 struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
1230  using type = FloatAttr;
1231 };
1232 
1233 template <typename ValT>
1234 using attr_holder_t = typename AttrHolder<ValT>::type;
1235 
1236 template <typename ValT,
1237  typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
1238 attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
1239  return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val);
1240 }
1241 
1242 template <typename valueT>
1243 parsed_inst_t ExpressionParser::parseConstInst(
1244  OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
1245  auto parsedConstant = parser.parseLiteral<valueT>();
1246  if (failed(parsedConstant))
1247  return failure();
1248  auto constOp =
1249  ConstOp::create(builder, *currentOpLoc,
1250  buildLiteralAttr<valueT>(builder, *parsedConstant));
1251  return {{constOp.getResult()}};
1252 }
1253 
1254 template <>
1255 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1257  return parseConstInst<int32_t>(builder);
1258 }
1259 
1260 template <>
1261 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1263  return parseConstInst<int64_t>(builder);
1264 }
1265 
1266 template <>
1267 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1269  return parseConstInst<float>(builder);
1270 }
1271 
1272 template <>
1273 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1275  return parseConstInst<double>(builder);
1276 }
1277 
1278 template <typename opcode, typename valueType, unsigned int numOperands>
1279 inline parsed_inst_t ExpressionParser::buildNumericOp(
1280  OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> *) {
1281  auto ty = buildLiteralType<valueType>(builder);
1282  LDBG() << "*** buildNumericOp: numOperands = " << numOperands
1283  << ", type = " << ty << " ***";
1284  auto tysToPop = SmallVector<Type, numOperands>();
1285  tysToPop.resize(numOperands);
1286  llvm::fill(tysToPop, ty);
1287  auto operands = popOperands(tysToPop);
1288  if (failed(operands))
1289  return failure();
1290  auto op = opcode::create(builder, *currentOpLoc, *operands).getResult();
1291  LDBG() << "Built operation: " << op;
1292  return {{op}};
1293 }
1294 
// Defines the `parseSpecificInstruction` specialization for the opcode named
// `MNEMONIC##TY_SUFFIX`. It pops `ARITY` operands of the MLIR type matching
// C++ type `CPP_T` and builds an `OP` from them (see `buildNumericOp`).
#define BUILD_NUMERIC_OP(OP, ARITY, MNEMONIC, TY_SUFFIX, CPP_T)                \
  template <>                                                                  \
  inline parsed_inst_t ExpressionParser::parseSpecificInstruction<             \
      WasmBinaryEncoding::OpCode::MNEMONIC##TY_SUFFIX>(OpBuilder & builder) {  \
    return buildNumericOp<OP, CPP_T, ARITY>(builder);                          \
  }

// Binary operations defined only for the integer types i32 and i64.
#define BUILD_NUMERIC_BINOP_INT(OP, MNEMONIC)                                  \
  BUILD_NUMERIC_OP(OP, 2, MNEMONIC, I32, int32_t)                              \
  BUILD_NUMERIC_OP(OP, 2, MNEMONIC, I64, int64_t)

// Binary operations defined only for the floating-point types f32 and f64.
#define BUILD_NUMERIC_BINOP_FP(OP, MNEMONIC)                                   \
  BUILD_NUMERIC_OP(OP, 2, MNEMONIC, F32, float)                                \
  BUILD_NUMERIC_OP(OP, 2, MNEMONIC, F64, double)

// Binary operations defined for both the integer and floating-point types.
#define BUILD_NUMERIC_BINOP_INTFP(OP, MNEMONIC)                                \
  BUILD_NUMERIC_BINOP_INT(OP, MNEMONIC)                                        \
  BUILD_NUMERIC_BINOP_FP(OP, MNEMONIC)

// Unary operations defined only for the integer types i32 and i64.
#define BUILD_NUMERIC_UNARY_OP_INT(OP, MNEMONIC)                               \
  BUILD_NUMERIC_OP(OP, 1, MNEMONIC, I32, int32_t)                              \
  BUILD_NUMERIC_OP(OP, 1, MNEMONIC, I64, int64_t)

// Unary operations defined only for the floating-point types f32 and f64.
#define BUILD_NUMERIC_UNARY_OP_FP(OP, MNEMONIC)                                \
  BUILD_NUMERIC_OP(OP, 1, MNEMONIC, F32, float)                                \
  BUILD_NUMERIC_OP(OP, 1, MNEMONIC, F64, double)
1327 
1328 BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
1330 BUILD_NUMERIC_BINOP_FP(GeOp, ge)
1331 BUILD_NUMERIC_BINOP_FP(GtOp, gt)
1332 BUILD_NUMERIC_BINOP_FP(LeOp, le)
1333 BUILD_NUMERIC_BINOP_FP(LtOp, lt)
1336 BUILD_NUMERIC_BINOP_INT(AndOp, and)
1337 BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
1338 BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
1339 BUILD_NUMERIC_BINOP_INT(GeSIOp, geS)
1340 BUILD_NUMERIC_BINOP_INT(GeUIOp, geU)
1341 BUILD_NUMERIC_BINOP_INT(GtSIOp, gtS)
1342 BUILD_NUMERIC_BINOP_INT(GtUIOp, gtU)
1343 BUILD_NUMERIC_BINOP_INT(LeSIOp, leS)
1344 BUILD_NUMERIC_BINOP_INT(LeUIOp, leU)
1345 BUILD_NUMERIC_BINOP_INT(LtSIOp, ltS)
1346 BUILD_NUMERIC_BINOP_INT(LtUIOp, ltU)
1347 BUILD_NUMERIC_BINOP_INT(OrOp, or)
1348 BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
1349 BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
1350 BUILD_NUMERIC_BINOP_INT(RotlOp, rotl)
1351 BUILD_NUMERIC_BINOP_INT(RotrOp, rotr)
1352 BUILD_NUMERIC_BINOP_INT(ShLOp, shl)
1353 BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
1354 BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
1355 BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
1357 BUILD_NUMERIC_BINOP_INTFP(EqOp, eq)
1359 BUILD_NUMERIC_BINOP_INTFP(NeOp, ne)
1360 BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
1361 BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
1364 BUILD_NUMERIC_UNARY_OP_FP(NegOp, neg)
1365 BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
1366 BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
1367 BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz)
1368 BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz)
1369 BUILD_NUMERIC_UNARY_OP_INT(EqzOp, eqz)
1370 BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
1371 
1372 // Don't need these anymore so let's undef them.
1373 #undef BUILD_NUMERIC_BINOP_FP
1374 #undef BUILD_NUMERIC_BINOP_INT
1375 #undef BUILD_NUMERIC_BINOP_INTFP
1376 #undef BUILD_NUMERIC_UNARY_OP_FP
1377 #undef BUILD_NUMERIC_UNARY_OP_INT
1378 #undef BUILD_NUMERIC_OP
1379 #undef BUILD_NUMERIC_CAST_OP
1380 
1381 template <typename opType, typename inputType, typename outputType,
1382  typename... extraArgsT>
1383 inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
1384  extraArgsT... extraArgs) {
1385  static_assert(std::is_arithmetic_v<inputType>,
1386  "InputType should be an arithmetic type");
1387  static_assert(std::is_arithmetic_v<outputType>,
1388  "OutputType should be an arithmetic type");
1389  auto intype = buildLiteralType<inputType>(builder);
1390  auto outType = buildLiteralType<outputType>(builder);
1391  auto operand = popOperands(intype);
1392  if (failed(operand))
1393  return failure();
1394  auto op = opType::create(builder, *currentOpLoc, outType, operand->front(),
1395  extraArgs...);
1396  LDBG() << "Built operation: " << op;
1397  return {{op.getResult()}};
1398 }
1399 
1400 template <>
1401 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1403  return buildConvertOp<DemoteOp, double, float>(builder);
1404 }
1405 
1406 template <>
1407 inline parsed_inst_t
1408 ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::wrap>(
1409  OpBuilder &builder) {
1410  return buildConvertOp<WrapOp, int64_t, int32_t>(builder);
1411 }
1412 
1413 #define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP) \
1414  template <> \
1415  inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1416  WasmBinaryEncoding::OpCode::SOURCE_OP>(OpBuilder & builder) { \
1417  return buildConvertOp<TARGET_OP, IN_T, OUT_T>(builder); \
1418  }
1419 
1420 #define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH) \
1421  BUILD_CONVERSION_OP(uint32_t, DEST_T, convertUI32F##WIDTH, ConvertUOp) \
1422  BUILD_CONVERSION_OP(int32_t, DEST_T, convertSI32F##WIDTH, ConvertSOp) \
1423  BUILD_CONVERSION_OP(uint64_t, DEST_T, convertUI64F##WIDTH, ConvertUOp) \
1424  BUILD_CONVERSION_OP(int64_t, DEST_T, convertSI64F##WIDTH, ConvertSOp)
1425 
1426 BUILD_CONVERT_OP_FOR(float, 32)
1427 BUILD_CONVERT_OP_FOR(double, 64)
1428 
1429 #undef BUILD_CONVERT_OP_FOR
1430 
1431 BUILD_CONVERSION_OP(int32_t, int64_t, extendS, ExtendSI32Op)
1432 BUILD_CONVERSION_OP(int32_t, int64_t, extendU, ExtendUI32Op)
1433 
1434 #undef BUILD_CONVERSION_OP
1435 
1436 #define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH) \
1437  template <> \
1438  parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1439  WasmBinaryEncoding::OpCode::extendI##IT_WIDTH##EXTRACT_WIDTH##S>( \
1440  OpBuilder & builder) { \
1441  using inout_t = int##IT_WIDTH##_t; \
1442  auto attr = builder.getUI32IntegerAttr(EXTRACT_WIDTH); \
1443  return buildConvertOp<ExtendLowBitsSOp, inout_t, inout_t>(builder, attr); \
1444  }
1445 
1451 
1452 #undef BUILD_SLICE_EXTEND_PARSER
1453 
1454 template <>
1455 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1457  return buildConvertOp<PromoteOp, float, double>(builder);
1458 }
1459 
1460 #define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE) \
1461  template <> \
1462  inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1463  WasmBinaryEncoding::OpCode::reinterpretF##WIDTH##AsI##WIDTH>(OpBuilder & \
1464  builder) { \
1465  return buildConvertOp<ReinterpretOp, FP_TYPE, int##WIDTH##_t>(builder); \
1466  } \
1467  \
1468  template <> \
1469  inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1470  WasmBinaryEncoding::OpCode::reinterpretI##WIDTH##AsF##WIDTH>(OpBuilder & \
1471  builder) { \
1472  return buildConvertOp<ReinterpretOp, int##WIDTH##_t, FP_TYPE>(builder); \
1473  }
1474 
1475 BUILD_REINTERPRET_PARSER(32, float)
1476 BUILD_REINTERPRET_PARSER(64, double)
1477 
1478 #undef BUILD_REINTERPRET_PARSER
1479 
1480 class WasmBinaryParser {
1481 private:
1482  struct SectionRegistry {
1483  using section_location_t = StringRef;
1484 
1485  std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
1486  registry;
1487 
1488  template <WasmSectionType SecType>
1489  std::conditional_t<sectionShouldBeUnique(SecType),
1490  std::optional<section_location_t>,
1492  getContentForSection() const {
1493  constexpr auto idx = static_cast<size_t>(SecType);
1494  if constexpr (sectionShouldBeUnique(SecType)) {
1495  return registry[idx].empty() ? std::nullopt
1496  : std::make_optional(registry[idx][0]);
1497  } else {
1498  return registry[idx];
1499  }
1500  }
1501 
1502  bool hasSection(WasmSectionType secType) const {
1503  return !registry[static_cast<size_t>(secType)].empty();
1504  }
1505 
1506  ///
1507  /// @returns success if registration valid, failure in case registration
1508  /// can't be done (if another section of same type already exist and this
1509  /// section type should only be present once)
1510  ///
1511  LogicalResult registerSection(WasmSectionType secType,
1512  section_location_t location, Location loc) {
1513  if (sectionShouldBeUnique(secType) && hasSection(secType))
1514  return emitError(loc,
1515  "trying to add a second instance of unique section");
1516 
1517  registry[static_cast<size_t>(secType)].push_back(location);
1518  emitRemark(loc, "Adding section with section ID ")
1519  << static_cast<uint8_t>(secType);
1520  return success();
1521  }
1522 
1523  LogicalResult populateFromBody(ParserHead ph) {
1524  while (!ph.end()) {
1525  FileLineColLoc sectionLoc = ph.getLocation();
1526  FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
1527  if (failed(secType))
1528  return failure();
1529 
1530  FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
1531  if (failed(secSizeParsed))
1532  return failure();
1533 
1534  uint32_t secSize = *secSizeParsed;
1535  FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
1536  if (failed(sectionContent))
1537  return failure();
1538 
1539  LogicalResult registration =
1540  registerSection(*secType, *sectionContent, sectionLoc);
1541 
1542  if (failed(registration))
1543  return failure();
1544  }
1545  return success();
1546  }
1547  };
1548 
1549  auto getLocation(int offset = 0) const {
1550  return FileLineColLoc::get(srcName, 0, offset);
1551  }
1552 
1553  template <WasmSectionType>
1554  LogicalResult parseSectionItem(ParserHead &, size_t);
1555 
1556  template <WasmSectionType section>
1557  LogicalResult parseSection() {
1558  auto secName = std::string{wasmSectionName<section>};
1559  auto sectionNameAttr =
1560  StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION");
1561  unsigned offset = 0;
1562  auto getLocation = [sectionNameAttr, &offset]() {
1563  return FileLineColLoc::get(sectionNameAttr, 0, offset);
1564  };
1565  auto secContent = registry.getContentForSection<section>();
1566  if (!secContent) {
1567  LDBG() << secName << " section is not present in file.";
1568  return success();
1569  }
1570 
1571  auto secSrc = secContent.value();
1572  ParserHead ph{secSrc, sectionNameAttr};
1573  FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
1574  if (failed(nElemsParsed))
1575  return failure();
1576  uint32_t nElems = *nElemsParsed;
1577  LDBG() << "starting to parse " << nElems << " items for section "
1578  << secName;
1579  for (size_t i = 0; i < nElems; ++i) {
1580  if (failed(parseSectionItem<section>(ph, i)))
1581  return failure();
1582  }
1583 
1584  if (!ph.end())
1585  return emitError(getLocation(), "unparsed garbage at end of section ")
1586  << secName;
1587  return success();
1588  }
1589 
1590  /// Handles the registration of a function import
1591  LogicalResult visitImport(Location loc, StringRef moduleName,
1592  StringRef importName, TypeIdxRecord tid) {
1593  using llvm::Twine;
1594  if (tid.id >= symbols.moduleFuncTypes.size())
1595  return emitError(loc, "invalid type id: ")
1596  << tid.id << ". Only " << symbols.moduleFuncTypes.size()
1597  << " type registrations";
1598  FunctionType type = symbols.moduleFuncTypes[tid.id];
1599  std::string symbol = symbols.getNewFuncSymbolName();
1600  auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
1601  importName, type);
1602  symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
1603  return funcOp.verify();
1604  }
1605 
1606  /// Handles the registration of a memory import
1607  LogicalResult visitImport(Location loc, StringRef moduleName,
1608  StringRef importName, LimitType limitType) {
1609  std::string symbol = symbols.getNewMemorySymbolName();
1610  auto memOp = MemImportOp::create(builder, loc, symbol, moduleName,
1611  importName, limitType);
1612  symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
1613  return memOp.verify();
1614  }
1615 
1616  /// Handles the registration of a table import
1617  LogicalResult visitImport(Location loc, StringRef moduleName,
1618  StringRef importName, TableType tableType) {
1619  std::string symbol = symbols.getNewTableSymbolName();
1620  auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName,
1621  importName, tableType);
1622  symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
1623  return tableOp.verify();
1624  }
1625 
1626  /// Handles the registration of a global variable import
1627  LogicalResult visitImport(Location loc, StringRef moduleName,
1628  StringRef importName, GlobalTypeRecord globalType) {
1629  std::string symbol = symbols.getNewGlobalSymbolName();
1630  auto giOp =
1631  GlobalImportOp::create(builder, loc, symbol, moduleName, importName,
1632  globalType.type, globalType.isMutable);
1633  symbols.globalSymbols.push_back(
1634  {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
1635  return giOp.verify();
1636  }
1637 
1638  // Detect occurence of errors
1639  LogicalResult peekDiag(Diagnostic &diag) {
1640  if (diag.getSeverity() == DiagnosticSeverity::Error)
1641  isValid = false;
1642  return failure();
1643  }
1644 
1645 public:
1646  WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
1647  : builder{ctx}, ctx{ctx} {
1649  [this](Diagnostic &diag) { return peekDiag(diag); });
1650  ctx->loadAllAvailableDialects();
1651  if (sourceMgr.getNumBuffers() != 1) {
1652  emitError(UnknownLoc::get(ctx), "one source file should be provided");
1653  return;
1654  }
1655  uint32_t sourceBufId = sourceMgr.getMainFileID();
1656  StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
1657  srcName = StringAttr::get(
1658  ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
1659 
1660  auto parser = ParserHead{source, srcName};
1661  auto const wasmHeader = StringRef{"\0asm", 4};
1662  FileLineColLoc magicLoc = parser.getLocation();
1663  FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
1664  if (failed(magic) || magic->compare(wasmHeader)) {
1665  emitError(magicLoc, "source file does not contain valid Wasm header");
1666  return;
1667  }
1668  auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
1669  FileLineColLoc versionLoc = parser.getLocation();
1670  FailureOr<StringRef> version =
1671  parser.consumeNBytes(expectedVersionString.size());
1672  if (failed(version))
1673  return;
1674  if (version->compare(expectedVersionString)) {
1675  emitError(versionLoc,
1676  "unsupported Wasm version. only version 1 is supported");
1677  return;
1678  }
1679  LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
1680  if (failed(fillRegistry))
1681  return;
1682 
1683  mOp = ModuleOp::create(builder, getLocation());
1684  builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
1685  LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
1686  if (failed(parsingTypes))
1687  return;
1688 
1689  LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
1690  if (failed(parsingImports))
1691  return;
1692 
1693  firstInternalFuncID = symbols.funcSymbols.size();
1694 
1695  LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
1696  if (failed(parsingFunctions))
1697  return;
1698 
1699  LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
1700  if (failed(parsingTables))
1701  return;
1702 
1703  LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
1704  if (failed(parsingMems))
1705  return;
1706 
1707  LogicalResult parsingGlobals = parseSection<WasmSectionType::GLOBAL>();
1708  if (failed(parsingGlobals))
1709  return;
1710 
1711  LogicalResult parsingCode = parseSection<WasmSectionType::CODE>();
1712  if (failed(parsingCode))
1713  return;
1714 
1715  LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
1716  if (failed(parsingExports))
1717  return;
1718 
1719  // Copy over sizes of containers into statistics.
1720  LDBG() << "WASM Imports:"
1721  << "\n"
1722  << " - Num functions: " << symbols.funcSymbols.size() << "\n"
1723  << " - Num globals: " << symbols.globalSymbols.size() << "\n"
1724  << " - Num memories: " << symbols.memSymbols.size() << "\n"
1725  << " - Num tables: " << symbols.tableSymbols.size();
1726  }
1727 
1728  ModuleOp getModule() {
1729  if (isValid)
1730  return mOp;
1731  if (mOp)
1732  mOp.erase();
1733  return ModuleOp{};
1734  }
1735 
1736 private:
1737  mlir::StringAttr srcName;
1738  OpBuilder builder;
1739  WasmModuleSymbolTables symbols;
1740  MLIRContext *ctx;
1741  ModuleOp mOp;
1742  SectionRegistry registry;
1743  size_t firstInternalFuncID{0};
1744  bool isValid{true};
1745 };
1746 
1747 template <>
1748 LogicalResult
1749 WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
1750  size_t) {
1751  FileLineColLoc importLoc = ph.getLocation();
1752  auto moduleName = ph.parseName();
1753  if (failed(moduleName))
1754  return failure();
1755 
1756  auto importName = ph.parseName();
1757  if (failed(importName))
1758  return failure();
1759 
1760  FailureOr<ImportDesc> import = ph.parseImportDesc(ctx);
1761  if (failed(import))
1762  return failure();
1763 
1764  return std::visit(
1765  [this, importLoc, &moduleName, &importName](auto import) {
1766  return visitImport(importLoc, *moduleName, *importName, import);
1767  },
1768  *import);
1769 }
1770 
1771 template <>
1772 LogicalResult
1773 WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
1774  size_t) {
1775  FileLineColLoc exportLoc = ph.getLocation();
1776 
1777  auto exportName = ph.parseName();
1778  if (failed(exportName))
1779  return failure();
1780 
1781  FailureOr<std::byte> opcode = ph.consumeByte();
1782  if (failed(opcode))
1783  return failure();
1784 
1785  FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
1786  if (failed(idx))
1787  return failure();
1788 
1789  using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
1792 
1793  SymbolRefDesc currentSymbolList;
1794  std::string symbolType = "";
1795  switch (*opcode) {
1797  symbolType = "function";
1798  currentSymbolList = symbols.funcSymbols;
1799  break;
1801  symbolType = "table";
1802  currentSymbolList = symbols.tableSymbols;
1803  break;
1805  symbolType = "memory";
1806  currentSymbolList = symbols.memSymbols;
1807  break;
1809  symbolType = "global";
1810  currentSymbolList = symbols.globalSymbols;
1811  break;
1812  default:
1813  return emitError(exportLoc, "invalid value for export type: ")
1814  << std::to_integer<unsigned>(*opcode);
1815  }
1816 
1817  auto currentSymbol = std::visit(
1818  [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> {
1819  if (*idx > list.size()) {
1820  emitError(
1821  exportLoc,
1822  llvm::formatv(
1823  "trying to export {0} {1} which is undefined in this scope",
1824  symbolType, *idx));
1825  return failure();
1826  }
1827  return list[*idx].symbol;
1828  },
1829  currentSymbolList);
1830 
1831  if (failed(currentSymbol))
1832  return failure();
1833 
1834  Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
1835  op->setAttr("exported", UnitAttr::get(op->getContext()));
1836  StringAttr symName = SymbolTable::getSymbolName(op);
1837  return SymbolTable{mOp}.rename(symName, *exportName);
1838 }
1839 
1840 template <>
1841 LogicalResult
1842 WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
1843  size_t) {
1844  FileLineColLoc opLocation = ph.getLocation();
1845  FailureOr<TableType> tableType = ph.parseTableType(ctx);
1846  if (failed(tableType))
1847  return failure();
1848  LDBG() << " Parsed table description: " << *tableType;
1849  StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
1850  auto tableOp =
1851  TableOp::create(builder, opLocation, symbol.strref(), *tableType);
1852  symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
1853  return success();
1854 }
1855 
1856 template <>
1857 LogicalResult
1858 WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
1859  size_t) {
1860  FileLineColLoc opLoc = ph.getLocation();
1861  auto typeIdxParsed = ph.parseLiteral<uint32_t>();
1862  if (failed(typeIdxParsed))
1863  return failure();
1864  uint32_t typeIdx = *typeIdxParsed;
1865  if (typeIdx >= symbols.moduleFuncTypes.size())
1866  return emitError(getLocation(), "invalid type index: ") << typeIdx;
1867  std::string symbol = symbols.getNewFuncSymbolName();
1868  auto funcOp =
1869  FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
1870  Block *block = funcOp.addEntryBlock();
1871  OpBuilder::InsertionGuard guard{builder};
1872  builder.setInsertionPointToEnd(block);
1873  ReturnOp::create(builder, opLoc);
1874  symbols.funcSymbols.push_back(
1875  {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())},
1876  symbols.moduleFuncTypes[typeIdx]});
1877  return funcOp.verify();
1878 }
1879 
1880 template <>
1881 LogicalResult
1882 WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
1883  size_t) {
1884  FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
1885  if (failed(funcType))
1886  return failure();
1887  LDBG() << "Parsed function type " << *funcType;
1888  symbols.moduleFuncTypes.push_back(*funcType);
1889  return success();
1890 }
1891 
1892 template <>
1893 LogicalResult
1894 WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
1895  size_t) {
1896  FileLineColLoc opLocation = ph.getLocation();
1897  FailureOr<LimitType> memory = ph.parseLimit(ctx);
1898  if (failed(memory))
1899  return failure();
1900 
1901  LDBG() << " Registering memory " << *memory;
1902  std::string symbol = symbols.getNewMemorySymbolName();
1903  auto memOp = MemOp::create(builder, opLocation, symbol, *memory);
1904  symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
1905  return success();
1906 }
1907 
1908 template <>
1909 LogicalResult
1910 WasmBinaryParser::parseSectionItem<WasmSectionType::GLOBAL>(ParserHead &ph,
1911  size_t) {
1912  FileLineColLoc globalLocation = ph.getLocation();
1913  auto globalTypeParsed = ph.parseGlobalType(ctx);
1914  if (failed(globalTypeParsed))
1915  return failure();
1916 
1917  GlobalTypeRecord globalType = *globalTypeParsed;
1918  auto symbol = builder.getStringAttr(symbols.getNewGlobalSymbolName());
1919  auto globalOp = wasmssa::GlobalOp::create(
1920  builder, globalLocation, symbol, globalType.type, globalType.isMutable);
1921  symbols.globalSymbols.push_back(
1922  {{FlatSymbolRefAttr::get(globalOp)}, globalOp.getType()});
1923  OpBuilder::InsertionGuard guard{builder};
1924  Block *block = builder.createBlock(&globalOp.getInitializer());
1925  builder.setInsertionPointToStart(block);
1926  parsed_inst_t expr = ph.parseExpression(builder, symbols);
1927  if (failed(expr))
1928  return failure();
1929  if (block->empty())
1930  return emitError(globalLocation, "global with empty initializer");
1931  if (expr->size() != 1 && (*expr)[0].getType() != globalType.type)
1932  return emitError(
1933  globalLocation,
1934  "initializer result type does not match global declaration type");
1935  ReturnOp::create(builder, globalLocation, *expr);
1936  return success();
1937 }
1938 
1939 template <>
1940 LogicalResult WasmBinaryParser::parseSectionItem<WasmSectionType::CODE>(
1941  ParserHead &ph, size_t innerFunctionId) {
1942  unsigned long funcId = innerFunctionId + firstInternalFuncID;
1943  FunctionSymbolRefContainer symRef = symbols.funcSymbols[funcId];
1944  auto funcOp =
1945  dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(mOp, symRef.symbol));
1946  assert(funcOp);
1947  if (failed(ph.parseCodeFor(funcOp, symbols)))
1948  return failure();
1949  return success();
1950 }
1951 } // namespace
1952 
1953 namespace mlir::wasm {
1955  MLIRContext *context) {
1956  WasmBinaryParser wBN{source, context};
1957  ModuleOp mOp = wBN.getModule();
1958  if (mOp)
1959  return {mOp};
1960 
1961  return {nullptr};
1962 }
1963 } // namespace mlir::wasm
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define EXPORT
#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX)
#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH)
#define APPLY_WASM_SEC_TRANSFORM
#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE)
#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP)
#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH)
#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)
#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX)
#define mul(a, b)
#define add(a, b)
#define div(a, b)
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
FloatType getF32Type()
Definition: Builders.cpp:43
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerAttr getUI32IntegerAttr(uint32_t value)
Definition: Builders.cpp:212
FloatType getF64Type()
Definition: Builders.cpp:45
HandlerID registerHandler(HandlerTy handler)
Register a new handler for diagnostics to the engine.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
An instance of this location represents a tuple of file, line number, and column number.
Definition: Location.h:174
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:157
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
DiagnosticEngine & getDiagEngine()
Returns the diagnostic engine for this context.
void loadAllAvailableDialects()
Load all dialects available in the registry in this context.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:385
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:390
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:448
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op referred to by the given name to the given new name and updates the sym...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
OwningOpRef< ModuleOp > importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext *context)
If source contains a valid Wasm binary file, this function returns a ModuleOp containing the repres...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static constexpr std::byte memory
static constexpr std::byte table
static constexpr std::byte global
static constexpr std::byte function
Byte encodings describing the mutability of globals.
static constexpr std::byte memType
static constexpr std::byte typeID
static constexpr std::byte tableType
static constexpr std::byte globalType
Byte encodings for Wasm limits.
static constexpr std::byte globalGet
static constexpr std::byte elseOpCode
static constexpr std::byte promoteF32ToF64
static constexpr std::byte demoteF64ToF32
static constexpr std::byte constI64
static constexpr std::byte constFP64
static constexpr std::byte localTee
static constexpr std::byte ifOpCode
static constexpr std::byte localGet
static constexpr std::byte branchIf
static constexpr std::byte localSet
static constexpr std::byte constI32
static constexpr std::byte constFP32
static constexpr std::byte externRef
static constexpr std::byte i32
static constexpr std::byte funcType
static constexpr std::byte i64
static constexpr std::byte emptyBlockType
static constexpr std::byte funcRef
static constexpr std::byte v128
static constexpr std::byte f64
static constexpr std::byte f32
static constexpr std::byte endByte
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.