MLIR 22.0.0git
TranslateFromWasm.cpp
Go to the documentation of this file.
1//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the WebAssembly importer.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/Location.h"
20#include "mlir/Support/LLVM.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/DebugLog.h"
25#include "llvm/Support/Endian.h"
26#include "llvm/Support/FormatVariadic.h"
27#include "llvm/Support/LEB128.h"
28#include "llvm/Support/LogicalResult.h"
29
30#include <cstddef>
31#include <cstdint>
32#include <variant>
33
34#define DEBUG_TYPE "wasm-translate"
35
36static_assert(CHAR_BIT == 8,
37 "This code expects std::byte to be exactly 8 bits");
38
39using namespace mlir;
40using namespace mlir::wasm;
41using namespace mlir::wasmssa;
42
43namespace {
/// Integer type used to hold the numeric identifier of a Wasm section.
using section_id_t = uint8_t;

/// Kinds of sections handled by the importer, each mapped to its numeric
/// identifier.
enum class WasmSectionType : section_id_t {
  CUSTOM = 0,
  TYPE = 1,
  IMPORT = 2,
  FUNCTION = 3,
  TABLE = 4,
  MEMORY = 5,
  GLOBAL = 6,
  EXPORT = 7,
  START = 8,
  ELEMENT = 9,
  CODE = 10,
  DATA = 11,
  DATACOUNT = 12,
};

/// Largest numeric identifier that corresponds to a WasmSectionType
/// enumerator.
constexpr section_id_t highestWasmSectionID =
    static_cast<section_id_t>(WasmSectionType::DATACOUNT);
63
// X-macro list of every section kind: expands to one
// WASM_SEC_TRANSFORM(<name>) invocation per WasmSectionType enumerator.
// WASM_SEC_TRANSFORM must be defined by the user before expanding this list.
64#define APPLY_WASM_SEC_TRANSFORM \
65 WASM_SEC_TRANSFORM(CUSTOM) \
66 WASM_SEC_TRANSFORM(TYPE) \
67 WASM_SEC_TRANSFORM(IMPORT) \
68 WASM_SEC_TRANSFORM(FUNCTION) \
69 WASM_SEC_TRANSFORM(TABLE) \
70 WASM_SEC_TRANSFORM(MEMORY) \
71 WASM_SEC_TRANSFORM(GLOBAL) \
72 WASM_SEC_TRANSFORM(EXPORT) \
73 WASM_SEC_TRANSFORM(START) \
74 WASM_SEC_TRANSFORM(ELEMENT) \
75 WASM_SEC_TRANSFORM(CODE) \
76 WASM_SEC_TRANSFORM(DATA) \
77 WASM_SEC_TRANSFORM(DATACOUNT)
78
/// Name of a section kind as a C string. The primary template yields an empty
/// string; the WASM_SEC_TRANSFORM definition below declares a specialization
/// whose value is the stringized enumerator name.
79template <WasmSectionType>
80constexpr const char *wasmSectionName = "";
81
// NOTE(review): no expansion of APPLY_WASM_SEC_TRANSFORM is visible between
// this #define and the #undef below (the listing's numbering skips a line
// here) -- confirm against the original file that the invocation that
// actually instantiates the specializations was not lost.
82#define WASM_SEC_TRANSFORM(section) \
83 template <> \
84 [[maybe_unused]] constexpr const char \
85 *wasmSectionName<WasmSectionType::section> = #section;
87#undef WASM_SEC_TRANSFORM
88
89constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
90 return secType != WasmSectionType::CUSTOM;
91}
92
/// Compile-time list of bytes carried entirely in the type; instances hold no
/// data and are only used to pass the list to function templates.
template <std::byte... Values>
struct ByteSequence {};

/// A ByteSequence that contains exactly one byte.
template <std::byte Value>
struct UniqueByte : ByteSequence<Value> {};
99
/// Byte sequence of value type encodings. parseBlockType uses it to decide
/// whether the next byte denotes a value type.
// NOTE(review): the template argument list looks truncated in this listing
// (the numbering jumps from 100 to 103 and only the v128 entry is visible) --
// confirm the complete list against the original file.
100[[maybe_unused]] constexpr ByteSequence<
103 WasmBinaryEncoding::Type::v128> valueTypesEncodings{};
104
105template <std::byte... allowedFlags>
106constexpr bool isValueOneOf(std::byte value,
107 ByteSequence<allowedFlags...> = {}) {
108 return ((value == allowedFlags) | ... | false);
109}
110
111template <std::byte... flags>
112constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
113 return !isValueOneOf<flags...>(value);
114}
115
/// Result of parseGlobalType: the value type of a global and whether the
/// mutability byte marked it as mutable.
116struct GlobalTypeRecord {
117 Type type;
118 bool isMutable;
119};
120
/// Index of a type. getFuncTypeFor resolves it against
/// WasmModuleSymbolTables::moduleFuncTypes.
121struct TypeIdxRecord {
122 size_t id;
123};
124
/// Holder for the symbol reference of a module-level entity.
125struct SymbolRefContainer {
126 FlatSymbolRefAttr symbol;
127};
128
/// Symbol reference paired with a type for a global.
129struct GlobalSymbolRefContainer : SymbolRefContainer {
130 Type globalType;
131};
132
/// Symbol reference paired with a function type.
133struct FunctionSymbolRefContainer : SymbolRefContainer {
134 FunctionType functionType;
135};
136
/// Result of parseImportDesc; one alternative per descriptor parser it
/// calls (type index, table type, limit, global type).
137using ImportDesc =
138 std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
139
/// Result of parsing instructions: the produced values, or failure.
140using parsed_inst_t = FailureOr<SmallVector<Value>>;
141
/// Tag returned by parseBlockType when the block type byte is
/// WasmBinaryEncoding::Type::emptyBlockType.
142struct EmptyBlockMarker {};
/// Result of parseBlockType: empty marker, type index, or a single value type.
143using BlockTypeParseResult =
144 std::variant<EmptyBlockMarker, TypeIdxRecord, Type>;
145
146struct WasmModuleSymbolTables {
147 SmallVector<FunctionSymbolRefContainer> funcSymbols;
148 SmallVector<GlobalSymbolRefContainer> globalSymbols;
149 SmallVector<SymbolRefContainer> memSymbols;
150 SmallVector<SymbolRefContainer> tableSymbols;
151 SmallVector<FunctionType> moduleFuncTypes;
152
153 std::string getNewSymbolName(StringRef prefix, size_t id) const {
154 return (prefix + Twine{id}).str();
155 }
156
157 std::string getNewFuncSymbolName() const {
158 size_t id = funcSymbols.size();
159 return getNewSymbolName("func_", id);
160 }
161
162 std::string getNewGlobalSymbolName() const {
163 size_t id = globalSymbols.size();
164 return getNewSymbolName("global_", id);
165 }
166
167 std::string getNewMemorySymbolName() const {
168 size_t id = memSymbols.size();
169 return getNewSymbolName("mem_", id);
170 }
171
172 std::string getNewTableSymbolName() const {
173 size_t id = tableSymbols.size();
174 return getNewSymbolName("table_", id);
175 }
176};
177
178class ParserHead;
179
180/// Wrapper around SmallVector to only allow access as push and pop on the
181/// stack. Makes sure that there are no "free accesses" on the stack to preserve
182/// its state.
183/// This class also keep tracks of the Wasm labels defined by different ops,
184/// which can be targeted by control flow ops. This can be modeled as part of
185/// the Value Stack as Wasm control flow ops can only target enclosing labels.
186class ValueStack {
187private:
188 struct LabelLevel {
189 size_t stackIdx;
190 LabelLevelOpInterface levelOp;
191 };
192
193public:
194 bool empty() const { return values.empty(); }
195
196 size_t size() const { return values.size(); }
197
198 /// Pops values from the stack because they are being used in an operation.
199 /// @param operandTypes The list of expected types of the operation, used
200 /// to know how many values to pop and check if the types match the
201 /// expectation.
202 /// @param opLoc Location of the caller, used to report accurately the
203 /// location
204 /// if an error occurs.
205 /// @return Failure or the vector of popped values.
206 FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
207 Location *opLoc);
208
209 /// Push the results of an operation to the stack so they can be used in a
210 /// following operation.
211 /// @param results The list of results of the operation
212 /// @param opLoc Location of the caller, used to report accurately the
213 /// location
214 /// if an error occurs.
215 LogicalResult pushResults(ValueRange results, Location *opLoc);
216
217 void addLabelLevel(LabelLevelOpInterface levelOp) {
218 labelLevel.push_back({values.size(), levelOp});
219 LDBG() << "Adding a new frame context to ValueStack";
220 }
221
222 void dropLabelLevel() {
223 assert(!labelLevel.empty() && "Trying to drop a frame from empty context");
224 auto newSize = labelLevel.pop_back_val().stackIdx;
225 values.truncate(newSize);
226 }
227#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
228 /// A simple dump function for debugging.
229 /// Writes output to llvm::dbgs().
230 LLVM_DUMP_METHOD void dump() const;
231#endif
232
233private:
234 SmallVector<Value> values;
235 SmallVector<LabelLevel> labelLevel;
236};
237
/// Value statically typed as wasmssa::LocalRefType. parseCodeFor stores the
/// function's block arguments and the results of LocalOp under this type.
238using local_val_t = TypedValue<wasmssa::LocalRefType>;
239
240class ExpressionParser {
241public:
242 using locals_t = SmallVector<local_val_t>;
243 ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
244 ArrayRef<local_val_t> initLocal)
245 : parser{parser}, symbols{symbols}, locals{initLocal} {}
246
247private:
 /// Parser for the instruction whose encoding is \p opCode.
 /// dispatchToInstParser calls the instantiation matching the byte it
 /// decoded; the primary template defined later in this file reports an
 /// "unknown instruction opcode" error.
248 template <std::byte opCode>
249 inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
250
 /// Only usable when \p valueT is an arithmetic type (enforced through the
 /// defaulted enable_if parameter).
 // NOTE(review): presumably parses a constant instruction whose immediate is
 // of type valueT -- the definition is not visible here, confirm.
251 template <typename valueT>
252 parsed_inst_t
253 parseConstInst(OpBuilder &builder,
254 std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
255
256 /// Construct an operation with \p numOperands operands and a single result.
257 /// Each operand must have the same type. Suitable for e.g. binops, unary
258 /// ops, etc.
259 ///
260 /// \p opcode - The WASM opcode to build.
261 /// \p valueType - The operand and result type for the built instruction.
262 /// \p numOperands - The number of operands for the built operation.
263 ///
264 /// \returns The parsed instruction result, or failure.
265 template <typename opcode, typename valueType, unsigned int numOperands>
266 inline parsed_inst_t
267 buildNumericOp(OpBuilder &builder,
268 std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
269
270 /// Construct a conversion operation of type \p opType that takes a value from
271 /// type \p inputType on the stack and will produce a value of type
272 /// \p outputType.
273 ///
274 /// \p opType - The WASM dialect operation to build.
275 /// \p inputType - The operand type for the built instruction.
276 /// \p outputType - The result type for the built instruction.
277 ///
 /// Any additional arguments are taken by value through \p extraArgsT.
278 /// \returns The parsed instruction result, or failure.
279 template <typename opType, typename inputType, typename outputType,
280 typename... extraArgsT>
281 inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...);
282
283 /// This function generates a dispatch tree to associate an opcode with a
284 /// parser. Parsers are registered by specialising the
285 /// `parseSpecificInstruction` function for the op code to handle.
286 ///
287 /// The dispatcher is generated by recursively creating all possible patterns
288 /// for an opcode and calling the relevant parser on the leaf.
289 ///
290 /// @tparam patternBitSize is the first bit for which the pattern is not fixed
291 ///
292 /// @tparam highBitPattern is the fixed pattern that this instance handles for
293 /// the 8-patternBitSize bits
294 template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
295 inline parsed_inst_t dispatchToInstParser(std::byte opCode,
296 OpBuilder &builder) {
297 static_assert(patternBitSize <= 8,
298 "PatternBitSize is outside of range of opcode space! "
299 "(expected at most 8 bits)");
300 if constexpr (patternBitSize < 8) {
301 constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
302 constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
303 constexpr size_t nextPatternBitSize = patternBitSize + 1;
304 if ((opCode & bitSelect) != std::byte{0})
305 return dispatchToInstParser<nextPatternBitSize,
306 nextHighBitPatternStem | std::byte{1}>(
307 opCode, builder);
308 return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
309 opCode, builder);
310 } else {
311 return parseSpecificInstruction<highBitPattern>(builder);
312 }
313 }
314
315 ///
316 /// RAII guard class for creating a nesting level
317 ///
318 struct NestingContextGuard {
319 NestingContextGuard(ExpressionParser &parser, LabelLevelOpInterface levelOp)
320 : parser{parser} {
321 parser.addNestingContextLevel(levelOp);
322 }
323 NestingContextGuard(NestingContextGuard &&other) : parser{other.parser} {
324 other.shouldDropOnDestruct = false;
325 }
326 NestingContextGuard(NestingContextGuard const &) = delete;
327 ~NestingContextGuard() {
328 if (shouldDropOnDestruct)
329 parser.dropNestingContextLevel();
330 }
331 ExpressionParser &parser;
332 bool shouldDropOnDestruct = true;
333 };
334
335 void addNestingContextLevel(LabelLevelOpInterface levelOp) {
336 valueStack.addLabelLevel(levelOp);
337 }
338
339 void dropNestingContextLevel() {
340 // Should always succeed as we are droping the frame that was previously
341 // created.
342 valueStack.dropLabelLevel();
343 }
344
345 llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
346 EmptyBlockMarker) {
347 return builder.getFunctionType({}, {});
348 }
349
350 llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
351 TypeIdxRecord type) {
352 if (type.id >= symbols.moduleFuncTypes.size())
353 return emitError(*currentOpLoc,
354 "type index references nonexistent type (")
355 << type.id << "). Only " << symbols.moduleFuncTypes.size()
356 << " types are registered";
357 return symbols.moduleFuncTypes[type.id];
358 }
359
360 llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
361 Type valType) {
362 return builder.getFunctionType({}, {valType});
363 }
364
365 llvm::FailureOr<FunctionType>
366 getFuncTypeFor(OpBuilder &builder, BlockTypeParseResult parseResult) {
367 return std::visit(
368 [this, &builder](auto value) { return getFuncTypeFor(builder, value); },
369 parseResult);
370 }
371
372 llvm::FailureOr<FunctionType>
373 getFuncTypeFor(OpBuilder &builder,
374 llvm::FailureOr<BlockTypeParseResult> parseResult) {
375 if (llvm::failed(parseResult))
376 return failure();
377 return getFuncTypeFor(builder, *parseResult);
378 }
379
 /// Parses a block type from the byte stream and converts it to a function
 /// type through getFuncTypeFor.
380 llvm::FailureOr<FunctionType> parseBlockFuncType(OpBuilder &builder);
381
 /// Result of parsing a sequence of instructions up to a terminating byte.
382 struct ParseResultWithInfo {
 /// Values returned with the result (see the ByteSequence overload of parse).
383 SmallVector<Value> opResults;
 /// The byte that stopped the parsing.
384 std::byte endingByte;
385 };
386
387 template <typename FilterT = ByteSequence<WasmBinaryEncoding::endByte>>
388 /// @param blockToFill: the block which content will be populated
389 /// @param resType: the type that this block is supposed to return
390 llvm::FailureOr<std::byte>
391 parseBlockContent(OpBuilder &builder, Block *blockToFill, TypeRange resTypes,
392 Location opLoc, LabelLevelOpInterface levelOp,
393 FilterT parseEndBytes = {}) {
394 OpBuilder::InsertionGuard guard{builder};
395 builder.setInsertionPointToStart(blockToFill);
396 LDBG() << "parsing a block of type "
397 << builder.getFunctionType(blockToFill->getArgumentTypes(),
398 resTypes);
399 auto nC = addNesting(levelOp);
400
401 if (failed(pushResults(blockToFill->getArguments())))
402 return failure();
403 auto bodyParsingRes = parse(builder, parseEndBytes);
404 if (failed(bodyParsingRes))
405 return failure();
406 auto returnOperands = popOperands(resTypes);
407 if (failed(returnOperands))
408 return failure();
409 BlockReturnOp::create(builder, opLoc, *returnOperands);
410 LDBG() << "end of parsing of a block";
411 return bodyParsingRes->endingByte;
412 }
413
414public:
 /// Parses instructions until the single byte \p ParseEndByte is met and
 /// returns the values of the ParseResultWithInfo produced by the
 /// ByteSequence overload.
415 template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
416 parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
417
 /// Parses instructions until any byte of \p parsingEndFilters is met and
 /// returns the values along with the byte that ended the parsing.
418 template <std::byte... ExpressionParseEnd>
419 FailureOr<ParseResultWithInfo>
420 parse(OpBuilder &builder,
421 ByteSequence<ExpressionParseEnd...> parsingEndFilters);
422
423 NestingContextGuard addNesting(LabelLevelOpInterface levelOp) {
424 return NestingContextGuard{*this, levelOp};
425 }
426
427 FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes) {
428 return valueStack.popOperands(operandTypes, &currentOpLoc.value());
429 }
430
431 LogicalResult pushResults(ValueRange results) {
432 return valueStack.pushResults(results, &currentOpLoc.value());
433 }
434
435 /// The local.set and local.tee operations behave similarly and only differ
436 /// on their return value. This function factorizes the behavior of the two
437 /// operations in one place.
 /// \p OpToCreate selects which of the two operations is built.
438 template <typename OpToCreate>
439 parsed_inst_t parseSetOrTee(OpBuilder &);
440
441 /// Blocks and Loops have a similar format and differ only in how their exit
442 /// is handled, which doesn't matter at parsing time. Factorizes in one
443 /// function.
 /// \p OpToCreate is the operation created for the parsed construct.
444 template <typename OpToCreate>
445 parsed_inst_t parseBlockLikeOp(OpBuilder &);
446
447private:
 /// Location of the instruction currently being parsed; set at the start of
 /// each iteration of the parse loop and used for error reporting.
448 std::optional<Location> currentOpLoc;
 /// Byte stream the instructions are read from.
449 ParserHead &parser;
 /// Symbol tables of the module being imported.
450 WasmModuleSymbolTables const &symbols;
 /// Locals known to the parser, initialised from the constructor argument.
451 locals_t locals;
 /// Values and label levels tracked while parsing.
452 ValueStack valueStack;
453};
454
455class ParserHead {
456public:
 /// @param src Bytes to parse.
 /// @param name Name used as the file name of every location this parser
 /// builds.
457 ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
458 ParserHead(ParserHead &&) = default;
459
460private:
 // Copying is private: callers have to go through copy() to duplicate a
 // parser.
461 ParserHead(ParserHead const &other) = default;
462
463public:
464 auto getLocation() const {
465 return FileLineColLoc::get(locName, 0, anchorOffset + offset);
466 }
467
468 FailureOr<StringRef> consumeNBytes(size_t nBytes) {
469 LDBG() << "Consume " << nBytes << " bytes";
470 LDBG() << " Bytes remaining: " << size();
471 LDBG() << " Current offset: " << offset;
472 if (nBytes > size())
473 return emitError(getLocation(), "trying to extract ")
474 << nBytes << "bytes when only " << size() << "are available";
475
476 StringRef res = head.slice(offset, offset + nBytes);
477 offset += nBytes;
478 LDBG() << " Updated offset (+" << nBytes << "): " << offset;
479 return res;
480 }
481
482 FailureOr<std::byte> consumeByte() {
483 FailureOr<StringRef> res = consumeNBytes(1);
484 if (failed(res))
485 return failure();
486 return std::byte{*res->bytes_begin()};
487 }
488
 /// Parses a literal of type \p T at the current position. Only declared
 /// here; the explicit specializations defined after the class (float,
 /// double, uint32_t, int32_t, int64_t) provide the implementations.
489 template <typename T>
490 FailureOr<T> parseLiteral();
491
 /// Parses the size of a vector; forwards to parseLiteral<uint32_t>.
492 FailureOr<uint32_t> parseVectorSize();
493
494private:
495 // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed
496 // if parseLiteral specialization were moved here, but default GCC on Ubuntu
497 // 22.04 has bug with template specialization in class declaration
498 inline FailureOr<uint32_t> parseUI32();
 // Forwards to parseLiteral<int64_t>.
499 inline FailureOr<int64_t> parseI64();
500
501public:
502 FailureOr<StringRef> parseName() {
503 FailureOr<uint32_t> size = parseVectorSize();
504 if (failed(size))
505 return failure();
506
507 return consumeNBytes(*size);
508 }
509
510 FailureOr<WasmSectionType> parseWasmSectionType() {
511 FailureOr<std::byte> id = consumeByte();
512 if (failed(id))
513 return failure();
514 if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
515 return emitError(getLocation(), "invalid section ID: ")
516 << static_cast<int>(*id);
517 return static_cast<WasmSectionType>(*id);
518 }
519
520 FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
521 using WasmLimits = WasmBinaryEncoding::LimitHeader;
522 FileLineColLoc limitLocation = getLocation();
523 FailureOr<std::byte> limitHeader = consumeByte();
524 if (failed(limitHeader))
525 return failure();
526
527 if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
528 return emitError(limitLocation, "invalid limit header: ")
529 << static_cast<int>(*limitHeader);
530 FailureOr<uint32_t> minParse = parseUI32();
531 if (failed(minParse))
532 return failure();
533 std::optional<uint32_t> max{std::nullopt};
534 if (*limitHeader == WasmLimits::bothLimits) {
535 FailureOr<uint32_t> maxParse = parseUI32();
536 if (failed(maxParse))
537 return failure();
538 max = *maxParse;
539 }
540 return LimitType::get(ctx, *minParse, max);
541 }
542
 /// Parses a one-byte value type encoding into the corresponding type.
 /// An unrecognized encoding produces an error located at the encoding byte
 /// and a failure.
 // NOTE(review): the `case` labels of the switch are not visible in this
 // listing (the numbering skips every other line between 548 and 562), which
 // leaves only the return statements -- confirm the labels against the
 // original file before relying on this block.
543 FailureOr<Type> parseValueType(MLIRContext *ctx) {
 // Captured before consuming so that errors point at the encoding byte.
544 FileLineColLoc typeLoc = getLocation();
545 FailureOr<std::byte> typeEncoding = consumeByte();
546 if (failed(typeEncoding))
547 return failure();
548 switch (*typeEncoding) {
550 return IntegerType::get(ctx, 32);
552 return IntegerType::get(ctx, 64);
554 return Float32Type::get(ctx);
556 return Float64Type::get(ctx);
558 return IntegerType::get(ctx, 128);
560 return wasmssa::FuncRefType::get(ctx);
562 return wasmssa::ExternRefType::get(ctx);
563 default:
564 return emitError(typeLoc, "invalid value type encoding: ")
565 << static_cast<int>(*typeEncoding);
566 }
567 }
568
569 FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
570 using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
571 FailureOr<Type> typeParsed = parseValueType(ctx);
572 if (failed(typeParsed))
573 return failure();
574 FileLineColLoc mutLoc = getLocation();
575 FailureOr<std::byte> mutSpec = consumeByte();
576 if (failed(mutSpec))
577 return failure();
578 if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
579 return emitError(mutLoc, "invalid global mutability specifier: ")
580 << static_cast<int>(*mutSpec);
581 return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
582 }
583
584 FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
585 FailureOr<uint32_t> nParamsParsed = parseVectorSize();
586 if (failed(nParamsParsed))
587 return failure();
588 uint32_t nParams = *nParamsParsed;
589 SmallVector<Type> res{};
590 res.reserve(nParams);
591 for (size_t i = 0; i < nParams; ++i) {
592 FailureOr<Type> parsedType = parseValueType(ctx);
593 if (failed(parsedType))
594 return failure();
595 res.push_back(*parsedType);
596 }
597 return TupleType::get(ctx, res);
598 }
599
600 FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
601 FileLineColLoc typeLoc = getLocation();
602 FailureOr<std::byte> funcTypeHeader = consumeByte();
603 if (failed(funcTypeHeader))
604 return failure();
605 if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
606 return emitError(typeLoc, "invalid function type header byte. Expecting ")
607 << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
608 << " got " << std::to_integer<unsigned>(*funcTypeHeader);
609 FailureOr<TupleType> inputTypes = parseResultType(ctx);
610 if (failed(inputTypes))
611 return failure();
612
613 FailureOr<TupleType> resTypes = parseResultType(ctx);
614 if (failed(resTypes))
615 return failure();
616
617 return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
618 }
619
620 FailureOr<TypeIdxRecord> parseTypeIndex() {
621 FailureOr<uint32_t> res = parseUI32();
622 if (failed(res))
623 return failure();
624 return TypeIdxRecord{*res};
625 }
626
627 FailureOr<TableType> parseTableType(MLIRContext *ctx) {
628 FailureOr<Type> elmTypeParse = parseValueType(ctx);
629 if (failed(elmTypeParse))
630 return failure();
631 if (!isWasmRefType(*elmTypeParse))
632 return emitError(getLocation(), "invalid element type for table");
633 FailureOr<LimitType> limitParse = parseLimit(ctx);
634 if (failed(limitParse))
635 return failure();
636 return TableType::get(ctx, *elmTypeParse, *limitParse);
637 }
638
 /// Parses an import descriptor: one byte selecting the kind of import,
 /// followed by the matching payload (type index, table type, limit or
 /// global type).
 /// An unrecognized selector produces an error located at the selector byte
 /// and a failure.
 // NOTE(review): the `case` labels of the switch are not visible in this
 // listing (the numbering skips every other line between 649 and 657), which
 // leaves only the return statements -- confirm the labels against the
 // original file before relying on this block.
639 FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
 // Captured before consuming so that errors point at the selector byte.
640 FileLineColLoc importLoc = getLocation();
641 FailureOr<std::byte> importType = consumeByte();
 // Wraps the result of a payload parser into an ImportDesc, forwarding
 // failures.
642 auto packager = [](auto parseResult) -> FailureOr<ImportDesc> {
643 if (failed(parseResult))
644 return failure();
645 return {*parseResult};
646 };
647 if (failed(importType))
648 return failure();
649 switch (*importType) {
651 return packager(parseTypeIndex());
653 return packager(parseTableType(ctx));
655 return packager(parseLimit(ctx));
657 return packager(parseGlobalType(ctx));
658 default:
659 return emitError(importLoc, "invalid import type descriptor: ")
660 << static_cast<int>(*importType);
661 }
662 }
663
664 parsed_inst_t parseExpression(OpBuilder &builder,
665 WasmModuleSymbolTables const &symbols,
666 ArrayRef<local_val_t> locals = {}) {
667 auto eParser = ExpressionParser{*this, symbols, locals};
668 return eParser.parse(builder);
669 }
670
 /// Parses the code entry located at the current position as the body of
 /// \p func.
 ///
 /// Reads the size of the entry and consumes that many bytes, then parses
 /// the local declarations and the expression through a nested ParserHead
 /// restricted to those bytes. The ops are created in front of the
 /// placeholder return op of the function, which is finally replaced by a
 /// ReturnOp returning the values produced by the expression.
 ///
 /// @return failure when any parsing step fails or when bytes of the entry
 /// remain unparsed after the expression.
671 LogicalResult parseCodeFor(FuncOp func,
672 WasmModuleSymbolTables const &symbols) {
673 SmallVector<local_val_t> locals{};
674 // Entry block of the function; its arguments become the first locals below.
675 Block &block = func.getBody().front();
676 // The block is expected to hold only a placeholder return op at this point;
 // it is erased at the end, once the real return op has been created.
677 assert(func.getBody().getBlocks().size() == 1 &&
678 "Function should only have its default created block at this point");
679 assert(block.getOperations().size() == 1 &&
680 "Only the placeholder return op should be present at this point");
681 auto returnOp = cast<ReturnOp>(&block.back());
682 assert(returnOp);
683
684 FailureOr<uint32_t> codeSizeInBytes = parseUI32();
685 if (failed(codeSizeInBytes))
686 return failure();
687 FailureOr<StringRef> codeContent = consumeNBytes(*codeSizeInBytes);
688 if (failed(codeContent))
689 return failure();
 // Locations inside the body are reported under
 // "<parser name>::<function symbol>", with offsets counted from the start
 // of the code entry.
690 auto name = StringAttr::get(func->getContext(),
691 locName.str() + "::" + func.getSymName());
692 auto cParser = ParserHead{*codeContent, name};
693 FailureOr<uint32_t> localVecSize = cParser.parseVectorSize();
694 if (failed(localVecSize))
695 return failure();
 // Insert every new op before the placeholder return op.
696 OpBuilder builder{&func.getBody().front().back()};
697 for (auto arg : block.getArguments())
698 locals.push_back(cast<TypedValue<LocalRefType>>(arg));
699 // Declare the local ops
 // Each entry of the vector is a count followed by a value type, and
 // declares that many locals of the type.
700 uint32_t nVarVec = *localVecSize;
701 for (size_t i = 0; i < nVarVec; ++i) {
702 FileLineColLoc varLoc = cParser.getLocation();
703 FailureOr<uint32_t> nSubVar = cParser.parseUI32();
704 if (failed(nSubVar))
705 return failure();
706 FailureOr<Type> varT = cParser.parseValueType(func->getContext());
707 if (failed(varT))
708 return failure();
709 for (size_t j = 0; j < *nSubVar; ++j) {
710 auto local = LocalOp::create(builder, varLoc, *varT);
711 locals.push_back(local.getResult());
712 }
713 }
714 parsed_inst_t res = cParser.parseExpression(builder, symbols, locals);
715 if (failed(res))
716 return failure();
717 if (!cParser.end())
718 return emitError(cParser.getLocation(),
719 "unparsed garbage remaining at end of code block");
720 ReturnOp::create(builder, func->getLoc(), *res);
721 returnOp->erase();
722 return success();
723 }
724
725 llvm::FailureOr<BlockTypeParseResult> parseBlockType(MLIRContext *ctx) {
726 auto loc = getLocation();
727 auto blockIndicator = peek();
728 if (failed(blockIndicator))
729 return failure();
730 if (*blockIndicator == WasmBinaryEncoding::Type::emptyBlockType) {
731 offset += 1;
732 return {EmptyBlockMarker{}};
733 }
734 if (isValueOneOf(*blockIndicator, valueTypesEncodings))
735 return parseValueType(ctx);
736 /// Block type idx is a 32 bit positive integer encoded as a 33 bit signed
737 /// value
738 auto typeIdx = parseI64();
739 if (failed(typeIdx))
740 return failure();
741 if (*typeIdx < 0 || *typeIdx > std::numeric_limits<uint32_t>::max())
742 return emitError(loc, "type ID should be representable with an unsigned "
743 "32 bits integer. Got ")
744 << *typeIdx;
745 return {TypeIdxRecord{static_cast<uint32_t>(*typeIdx)}};
746 }
747
748 bool end() const { return curHead().empty(); }
749
750 ParserHead copy() const { return *this; }
751
752private:
753 StringRef curHead() const { return head.drop_front(offset); }
754
755 FailureOr<std::byte> peek() const {
756 if (end())
757 return emitError(
758 getLocation(),
759 "trying to peek at next byte, but input stream is empty");
760 return static_cast<std::byte>(curHead().front());
761 }
762
763 size_t size() const { return head.size() - offset; }
764
 /// Bytes being parsed (consumed or not).
765 StringRef head;
 /// Name used as the file name of the locations built by getLocation.
766 StringAttr locName;
 /// Constant added to `offset` when building a location; never modified in
 /// the code visible here.
767 unsigned anchorOffset{0};
 /// Number of bytes of `head` consumed so far.
768 unsigned offset{0};
769};
770
771template <>
772FailureOr<float> ParserHead::parseLiteral<float>() {
773 FailureOr<StringRef> bytes = consumeNBytes(4);
774 if (failed(bytes))
775 return failure();
776 return llvm::support::endian::read<float>(bytes->bytes_begin(),
777 llvm::endianness::little);
778}
779
780template <>
781FailureOr<double> ParserHead::parseLiteral<double>() {
782 FailureOr<StringRef> bytes = consumeNBytes(8);
783 if (failed(bytes))
784 return failure();
785 return llvm::support::endian::read<double>(bytes->bytes_begin(),
786 llvm::endianness::little);
787}
788
789template <>
790FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
791 char const *error = nullptr;
792 uint32_t res{0};
793 unsigned encodingSize{0};
794 StringRef src = curHead();
795 uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
796 src.bytes_end(), &error);
797 if (error)
798 return emitError(getLocation(), error);
799
800 if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max()))
801 return emitError(getLocation()) << "literal does not fit on 32 bits";
802
803 res = static_cast<uint32_t>(decoded);
804 offset += encodingSize;
805 return res;
806}
807
808template <>
809FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
810 char const *error = nullptr;
811 int32_t res{0};
812 unsigned encodingSize{0};
813 StringRef src = curHead();
814 int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
815 src.bytes_end(), &error);
816 if (error)
817 return emitError(getLocation(), error);
818 if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
819 std::isgreater(std::numeric_limits<int32_t>::min(), decoded))
820 return emitError(getLocation()) << "literal does not fit on 32 bits";
821
822 res = static_cast<int32_t>(decoded);
823 offset += encodingSize;
824 return res;
825}
826
827template <>
828FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
829 char const *error = nullptr;
830 unsigned encodingSize{0};
831 StringRef src = curHead();
832 int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
833 src.bytes_end(), &error);
834 if (error)
835 return emitError(getLocation(), error);
836
837 offset += encodingSize;
838 return res;
839}
840
/// Parses the size of a vector as an unsigned 32 bit literal.
841FailureOr<uint32_t> ParserHead::parseVectorSize() {
842 return parseLiteral<uint32_t>();
843}
844
/// Parses an unsigned 32 bit literal; forwards to parseLiteral<uint32_t>.
845inline FailureOr<uint32_t> ParserHead::parseUI32() {
846 return parseLiteral<uint32_t>();
847}
848
/// Parses a signed 64 bit literal; forwards to parseLiteral<int64_t>.
849inline FailureOr<int64_t> ParserHead::parseI64() {
850 return parseLiteral<int64_t>();
851}
852
853template <std::byte opCode>
854inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
855 return emitError(*currentOpLoc, "unknown instruction opcode: ")
856 << static_cast<int>(opCode);
857}
858
859#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
860void ValueStack::dump() const {
861 llvm::dbgs() << "================= Wasm ValueStack =======================\n";
862 llvm::dbgs() << "size: " << size() << "\n";
863 llvm::dbgs() << "nbFrames: " << labelLevel.size() << '\n';
864 llvm::dbgs() << "<Top>"
865 << "\n";
866 // Stack is pushed to via push_back. Therefore the top of the stack is the
867 // end of the vector. Iterate in reverse so that the first thing we print
868 // is the top of the stack.
869 auto indexGetter = [this]() {
870 size_t idx = labelLevel.size();
871 return [this, idx]() mutable -> std::optional<std::pair<size_t, size_t>> {
872 llvm::dbgs() << "IDX: " << idx << '\n';
873 if (idx == 0)
874 return std::nullopt;
875 auto frameId = idx - 1;
876 auto frameLimit = labelLevel[frameId].stackIdx;
877 idx -= 1;
878 return {{frameId, frameLimit}};
879 };
880 };
881 auto getNextFrameIndex = indexGetter();
882 auto nextFrameIdx = getNextFrameIndex();
883 size_t stackSize = size();
884 for (size_t idx = 0; idx < stackSize; ++idx) {
885 size_t actualIdx = stackSize - 1 - idx;
886 while (nextFrameIdx && (nextFrameIdx->second > actualIdx)) {
887 llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first
888 << ")\n";
889 nextFrameIdx = getNextFrameIndex();
890 }
891 llvm::dbgs() << " ";
892 values[actualIdx].dump();
893 }
894 while (nextFrameIdx) {
895 llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first << ")\n";
896 nextFrameIdx = getNextFrameIndex();
897 }
898 llvm::dbgs() << "<Bottom>"
899 << "\n";
900 llvm::dbgs() << "=========================================================\n";
901}
902#endif
903
904parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
905 LDBG() << "Popping from ValueStack\n"
906 << " Elements(s) to pop: " << operandTypes.size() << "\n"
907 << " Current stack size: " << values.size();
908 if (operandTypes.size() > values.size())
909 return emitError(*opLoc,
910 "stack doesn't contain enough values. trying to get ")
911 << operandTypes.size() << " operands on a stack containing only "
912 << values.size() << " values";
913 size_t stackIdxOffset = values.size() - operandTypes.size();
914 SmallVector<Value> res{};
915 res.reserve(operandTypes.size());
916 for (size_t i{0}; i < operandTypes.size(); ++i) {
917 Value operand = values[i + stackIdxOffset];
918 Type stackType = operand.getType();
919 if (stackType != operandTypes[i])
920 return emitError(*opLoc, "invalid operand type on stack. expecting ")
921 << operandTypes[i] << ", value on stack is of type " << stackType;
922 LDBG() << " POP: " << operand;
923 res.push_back(operand);
924 }
925 values.resize(values.size() - operandTypes.size());
926 LDBG() << " Updated stack size: " << values.size();
927 return res;
928}
929
930LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
931 LDBG() << "Pushing to ValueStack\n"
932 << " Elements(s) to push: " << results.size() << "\n"
933 << " Current stack size: " << values.size();
934 for (Value val : results) {
935 if (!isWasmValueType(val.getType()))
936 return emitError(*opLoc, "invalid value type on stack: ")
937 << val.getType();
938 LDBG() << " PUSH: " << val;
939 values.push_back(val);
940 }
941
942 LDBG() << " Updated stack size: " << values.size();
943 return success();
944}
945
946template <std::byte EndParseByte>
947parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
948 UniqueByte<EndParseByte> endByte) {
949 auto res = parse(builder, ByteSequence<EndParseByte>{});
950 if (failed(res))
951 return failure();
952 return res->opResults;
953}
954
955template <std::byte... ExpressionParseEnd>
956FailureOr<ExpressionParser::ParseResultWithInfo>
957ExpressionParser::parse(OpBuilder &builder,
958 ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
959 SmallVector<Value> res;
960 for (;;) {
961 currentOpLoc = parser.getLocation();
962 FailureOr<std::byte> opCode = parser.consumeByte();
963 if (failed(opCode))
964 return failure();
965 if (isValueOneOf(*opCode, parsingEndFilters))
966 return {{res, *opCode}};
967 parsed_inst_t resParsed;
968 resParsed = dispatchToInstParser(*opCode, builder);
969 if (failed(resParsed))
970 return failure();
971 std::swap(res, *resParsed);
972 if (failed(pushResults(res)))
973 return failure();
974 }
975}
976
977llvm::FailureOr<FunctionType>
978ExpressionParser::parseBlockFuncType(OpBuilder &builder) {
979 return getFuncTypeFor(builder, parser.parseBlockType(builder.getContext()));
980}
981
982template <typename OpToCreate>
983parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
984 auto opLoc = currentOpLoc;
985 auto funcType = parseBlockFuncType(builder);
986 if (failed(funcType))
987 return failure();
988
989 auto inputTypes = funcType->getInputs();
990 auto inputOps = popOperands(inputTypes);
991 if (failed(inputOps))
992 return failure();
993
994 Block *curBlock = builder.getBlock();
995 Region *curRegion = curBlock->getParent();
996 auto resTypes = funcType->getResults();
997 llvm::SmallVector<Location> locations{};
998 locations.resize(resTypes.size(), *currentOpLoc);
999 auto *successor =
1000 builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
1001 builder.setInsertionPointToEnd(curBlock);
1002 auto blockOp =
1003 OpToCreate::create(builder, *currentOpLoc, *inputOps, successor);
1004 auto *blockBody = blockOp.createBlock();
1005 if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp)))
1006 return failure();
1007 builder.setInsertionPointToStart(successor);
1008 return {ValueRange{successor->getArguments()}};
1009}
1010
1011template <>
1012inline parsed_inst_t
1013ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>(
1014 OpBuilder &builder) {
1015 return parseBlockLikeOp<BlockOp>(builder);
1016}
1017
1018template <>
1019inline parsed_inst_t
1020ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>(
1021 OpBuilder &builder) {
1022 return parseBlockLikeOp<LoopOp>(builder);
1023}
1024
1025template <>
1026inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1027 WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder) {
1028 auto opLoc = currentOpLoc;
1029 auto funcType = parseBlockFuncType(builder);
1030 if (failed(funcType))
1031 return failure();
1032
1033 LDBG() << "Parsing an if instruction of type " << *funcType;
1034 auto inputTypes = funcType->getInputs();
1035 auto conditionValue = popOperands(builder.getI32Type());
1036 if (failed(conditionValue))
1037 return failure();
1038 auto inputOps = popOperands(inputTypes);
1039 if (failed(inputOps))
1040 return failure();
1041
1042 Block *curBlock = builder.getBlock();
1043 Region *curRegion = curBlock->getParent();
1044 auto resTypes = funcType->getResults();
1045 llvm::SmallVector<Location> locations{};
1046 locations.resize(resTypes.size(), *currentOpLoc);
1047 auto *successor =
1048 builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
1049 builder.setInsertionPointToEnd(curBlock);
1050 auto ifOp = IfOp::create(builder, *currentOpLoc, conditionValue->front(),
1051 *inputOps, successor);
1052 auto *ifEntryBlock = ifOp.createIfBlock();
1053 constexpr auto ifElseFilter =
1054 ByteSequence<WasmBinaryEncoding::endByte,
1056 auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc,
1057 ifOp, ifElseFilter);
1058 if (failed(parseIfRes))
1059 return failure();
1060 if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) {
1061 LDBG() << " else block is present.";
1062 Block *elseEntryBlock = ifOp.createElseBlock();
1063 auto parseElseRes =
1064 parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp);
1065 if (failed(parseElseRes))
1066 return failure();
1067 }
1068 builder.setInsertionPointToStart(successor);
1069 return {ValueRange{successor->getArguments()}};
1070}
1071
1072template <>
1073inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1074 WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder) {
1075 auto level = parser.parseLiteral<uint32_t>();
1076 if (failed(level))
1077 return failure();
1078 Block *curBlock = builder.getBlock();
1079 Region *curRegion = curBlock->getParent();
1080 auto sip = builder.saveInsertionPoint();
1081 Block *elseBlock = builder.createBlock(curRegion, curRegion->end());
1082 auto condition = popOperands(builder.getI32Type());
1083 if (failed(condition))
1084 return failure();
1085 builder.restoreInsertionPoint(sip);
1086 auto targetOp =
1087 LabelBranchingOpInterface::getTargetOpFromBlock(curBlock, *level);
1088 if (failed(targetOp))
1089 return failure();
1090 auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes();
1091 auto branchArgs = popOperands(inputTypes);
1092 if (failed(branchArgs))
1093 return failure();
1094 BranchIfOp::create(builder, *currentOpLoc, condition->front(),
1095 builder.getUI32IntegerAttr(*level), *branchArgs,
1096 elseBlock);
1097 builder.setInsertionPointToStart(elseBlock);
1098 return {*branchArgs};
1099}
1100
1101template <>
1102inline parsed_inst_t
1103ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
1104 OpBuilder &builder) {
1105 auto loc = *currentOpLoc;
1106 auto funcIdx = parser.parseLiteral<uint32_t>();
1107 if (failed(funcIdx))
1108 return failure();
1109 if (*funcIdx >= symbols.funcSymbols.size())
1110 return emitError(loc, "Invalid function index: ") << *funcIdx;
1111 auto callee = symbols.funcSymbols[*funcIdx];
1112 llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs();
1113 llvm::ArrayRef<Type> resTypes = callee.functionType.getResults();
1114 parsed_inst_t inOperands = popOperands(inTypes);
1115 if (failed(inOperands))
1116 return failure();
1117 auto callOp =
1118 FuncCallOp::create(builder, loc, resTypes, callee.symbol, *inOperands);
1119 return {callOp.getResults()};
1120}
1121
1122template <>
1123inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1124 WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) {
1125 FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
1126 Location instLoc = *currentOpLoc;
1127 if (failed(id))
1128 return failure();
1129 if (*id >= locals.size())
1130 return emitError(instLoc, "invalid local index. function has ")
1131 << locals.size() << " accessible locals, received index " << *id;
1132 return {{LocalGetOp::create(builder, instLoc, locals[*id]).getResult()}};
1133}
1134
1135template <>
1136inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1137 WasmBinaryEncoding::OpCode::globalGet>(OpBuilder &builder) {
1138 FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
1139 Location instLoc = *currentOpLoc;
1140 if (failed(id))
1141 return failure();
1142 if (*id >= symbols.globalSymbols.size())
1143 return emitError(instLoc, "invalid global index. function has ")
1144 << symbols.globalSymbols.size()
1145 << " accessible globals, received index " << *id;
1146 GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id];
1147 auto globalOp = GlobalGetOp::create(builder, instLoc, globalVar.globalType,
1148 globalVar.symbol);
1149
1150 return {{globalOp.getResult()}};
1151}
1152
1153template <typename OpToCreate>
1154parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
1155 FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>();
1156 if (failed(id))
1157 return failure();
1158 if (*id >= locals.size())
1159 return emitError(*currentOpLoc, "invalid local index. function has ")
1160 << locals.size() << " accessible locals, received index " << *id;
1161 if (valueStack.empty())
1162 return emitError(
1163 *currentOpLoc,
1164 "invalid stack access, trying to access a value on an empty stack");
1165
1166 parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
1167 if (failed(poppedOp))
1168 return failure();
1169 return {
1170 OpToCreate::create(builder, *currentOpLoc, locals[*id], poppedOp->front())
1171 ->getResults()};
1172}
1173
1174template <>
1175inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1176 WasmBinaryEncoding::OpCode::localSet>(OpBuilder &builder) {
1177 return parseSetOrTee<LocalSetOp>(builder);
1178}
1179
1180template <>
1181inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1182 WasmBinaryEncoding::OpCode::localTee>(OpBuilder &builder) {
1183 return parseSetOrTee<LocalTeeOp>(builder);
1184}
1185
1186template <typename T>
1187inline Type buildLiteralType(OpBuilder &);
1188
1189template <>
1190inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
1191 return builder.getI32Type();
1192}
1193
1194template <>
1195inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
1196 return builder.getI64Type();
1197}
1198
1199template <>
1200[[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
1201 return builder.getI32Type();
1202}
1203
1204template <>
1205[[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
1206 return builder.getI64Type();
1207}
1208
1209template <>
1210inline Type buildLiteralType<float>(OpBuilder &builder) {
1211 return builder.getF32Type();
1212}
1213
1214template <>
1215inline Type buildLiteralType<double>(OpBuilder &builder) {
1216 return builder.getF64Type();
1217}
1218
1219template <typename ValT,
1220 typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
1221struct AttrHolder;
1222
1223template <typename ValT>
1224struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
1225 using type = IntegerAttr;
1226};
1227
1228template <typename ValT>
1229struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
1230 using type = FloatAttr;
1231};
1232
1233template <typename ValT>
1234using attr_holder_t = typename AttrHolder<ValT>::type;
1235
1236template <typename ValT,
1237 typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
1238attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
1239 return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val);
1240}
1241
1242template <typename valueT>
1243parsed_inst_t ExpressionParser::parseConstInst(
1244 OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
1245 auto parsedConstant = parser.parseLiteral<valueT>();
1246 if (failed(parsedConstant))
1247 return failure();
1248 auto constOp =
1249 ConstOp::create(builder, *currentOpLoc,
1250 buildLiteralAttr<valueT>(builder, *parsedConstant));
1251 return {{constOp.getResult()}};
1252}
1253
1254template <>
1255inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1256 WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
1257 return parseConstInst<int32_t>(builder);
1258}
1259
1260template <>
1261inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1262 WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
1263 return parseConstInst<int64_t>(builder);
1264}
1265
1266template <>
1267inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1268 WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
1269 return parseConstInst<float>(builder);
1270}
1271
1272template <>
1273inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1274 WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
1275 return parseConstInst<double>(builder);
1276}
1277
1278template <typename opcode, typename valueType, unsigned int numOperands>
1279inline parsed_inst_t ExpressionParser::buildNumericOp(
1280 OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> *) {
1281 auto ty = buildLiteralType<valueType>(builder);
1282 LDBG() << "*** buildNumericOp: numOperands = " << numOperands
1283 << ", type = " << ty << " ***";
1284 auto tysToPop = SmallVector<Type, numOperands>();
1285 tysToPop.resize(numOperands);
1286 llvm::fill(tysToPop, ty);
1287 auto operands = popOperands(tysToPop);
1288 if (failed(operands))
1289 return failure();
1290 auto op = opcode::create(builder, *currentOpLoc, *operands).getResult();
1291 LDBG() << "Built operation: " << op;
1292 return {{op}};
1293}
1294
// Generates the parser of the opcode PREFIX##SUFFIX as a numeric operation
// OP_NAME taking N_ARGS operands of the type associated to TYPE.
#define BUILD_NUMERIC_OP(OP_NAME, N_ARGS, PREFIX, SUFFIX, TYPE)                \
  template <>                                                                  \
  inline parsed_inst_t ExpressionParser::parseSpecificInstruction<             \
      WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>(OpBuilder & builder) {       \
    return buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder);                     \
  }

// Binary ops available on integer types only (I32 and I64 variants).
#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX)                               \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I32, int32_t)                           \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I64, int64_t)

// Binary ops available on floating point types only (F32 and F64 variants).
#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)                                \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F32, float)                             \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F64, double)

// Binary ops available on both integer and floating point types.
#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX)                             \
  BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX)                                     \
  BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)

// Unary ops available on integer types only (I32 and I64 variants).
#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX)                            \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I32, int32_t)                           \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I64, int64_t)

// Unary ops available on floating point types only (F32 and F64 variants).
#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX)                             \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F32, float)                             \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F64, double)
1327
1328BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
1330BUILD_NUMERIC_BINOP_FP(GeOp, ge)
1331BUILD_NUMERIC_BINOP_FP(GtOp, gt)
1332BUILD_NUMERIC_BINOP_FP(LeOp, le)
1333BUILD_NUMERIC_BINOP_FP(LtOp, lt)
1336BUILD_NUMERIC_BINOP_INT(AndOp, and)
1337BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
1338BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
1339BUILD_NUMERIC_BINOP_INT(GeSIOp, geS)
1340BUILD_NUMERIC_BINOP_INT(GeUIOp, geU)
1341BUILD_NUMERIC_BINOP_INT(GtSIOp, gtS)
1342BUILD_NUMERIC_BINOP_INT(GtUIOp, gtU)
1343BUILD_NUMERIC_BINOP_INT(LeSIOp, leS)
1344BUILD_NUMERIC_BINOP_INT(LeUIOp, leU)
1345BUILD_NUMERIC_BINOP_INT(LtSIOp, ltS)
1346BUILD_NUMERIC_BINOP_INT(LtUIOp, ltU)
1348BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
1349BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
1350BUILD_NUMERIC_BINOP_INT(RotlOp, rotl)
1351BUILD_NUMERIC_BINOP_INT(RotrOp, rotr)
1352BUILD_NUMERIC_BINOP_INT(ShLOp, shl)
1353BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
1354BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
1355BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
1360BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
1361BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
1362BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil)
1363BUILD_NUMERIC_UNARY_OP_FP(FloorOp, floor)
1364BUILD_NUMERIC_UNARY_OP_FP(NegOp, neg)
1365BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
1366BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
1370BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
1371
1372// Don't need these anymore so let's undef them.
1373#undef BUILD_NUMERIC_BINOP_FP
1374#undef BUILD_NUMERIC_BINOP_INT
1375#undef BUILD_NUMERIC_BINOP_INTFP
1376#undef BUILD_NUMERIC_UNARY_OP_FP
1377#undef BUILD_NUMERIC_UNARY_OP_INT
1378#undef BUILD_NUMERIC_OP
1379#undef BUILD_NUMERIC_CAST_OP
1380
1381template <typename opType, typename inputType, typename outputType,
1382 typename... extraArgsT>
1383inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
1384 extraArgsT... extraArgs) {
1385 static_assert(std::is_arithmetic_v<inputType>,
1386 "InputType should be an arithmetic type");
1387 static_assert(std::is_arithmetic_v<outputType>,
1388 "OutputType should be an arithmetic type");
1389 auto intype = buildLiteralType<inputType>(builder);
1390 auto outType = buildLiteralType<outputType>(builder);
1391 auto operand = popOperands(intype);
1392 if (failed(operand))
1393 return failure();
1394 auto op = opType::create(builder, *currentOpLoc, outType, operand->front(),
1395 extraArgs...);
1396 LDBG() << "Built operation: " << op;
1397 return {{op.getResult()}};
1398}
1399
1400template <>
1401inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1402 WasmBinaryEncoding::OpCode::demoteF64ToF32>(OpBuilder &builder) {
1403 return buildConvertOp<DemoteOp, double, float>(builder);
1404}
1405
1406template <>
1407inline parsed_inst_t
1408ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::wrap>(
1409 OpBuilder &builder) {
1410 return buildConvertOp<WrapOp, int64_t, int32_t>(builder);
1411}
1412
1413#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP) \
1414 template <> \
1415 inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1416 WasmBinaryEncoding::OpCode::SOURCE_OP>(OpBuilder & builder) { \
1417 return buildConvertOp<TARGET_OP, IN_T, OUT_T>(builder); \
1418 }
1419
1420#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH) \
1421 BUILD_CONVERSION_OP(uint32_t, DEST_T, convertUI32F##WIDTH, ConvertUOp) \
1422 BUILD_CONVERSION_OP(int32_t, DEST_T, convertSI32F##WIDTH, ConvertSOp) \
1423 BUILD_CONVERSION_OP(uint64_t, DEST_T, convertUI64F##WIDTH, ConvertUOp) \
1424 BUILD_CONVERSION_OP(int64_t, DEST_T, convertSI64F##WIDTH, ConvertSOp)
1425
1426BUILD_CONVERT_OP_FOR(float, 32)
1427BUILD_CONVERT_OP_FOR(double, 64)
1428
1429#undef BUILD_CONVERT_OP_FOR
1430
1431BUILD_CONVERSION_OP(int32_t, int64_t, extendS, ExtendSI32Op)
1432BUILD_CONVERSION_OP(int32_t, int64_t, extendU, ExtendUI32Op)
1433
1434#undef BUILD_CONVERSION_OP
1435
1436#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH) \
1437 template <> \
1438 parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1439 WasmBinaryEncoding::OpCode::extendI##IT_WIDTH##EXTRACT_WIDTH##S>( \
1440 OpBuilder & builder) { \
1441 using inout_t = int##IT_WIDTH##_t; \
1442 auto attr = builder.getUI32IntegerAttr(EXTRACT_WIDTH); \
1443 return buildConvertOp<ExtendLowBitsSOp, inout_t, inout_t>(builder, attr); \
1444 }
1445
1451
1452#undef BUILD_SLICE_EXTEND_PARSER
1453
1454template <>
1455inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
1457 return buildConvertOp<PromoteOp, float, double>(builder);
1458}
1459
1460#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE) \
1461 template <> \
1462 inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1463 WasmBinaryEncoding::OpCode::reinterpretF##WIDTH##AsI##WIDTH>(OpBuilder & \
1464 builder) { \
1465 return buildConvertOp<ReinterpretOp, FP_TYPE, int##WIDTH##_t>(builder); \
1466 } \
1467 \
1468 template <> \
1469 inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
1470 WasmBinaryEncoding::OpCode::reinterpretI##WIDTH##AsF##WIDTH>(OpBuilder & \
1471 builder) { \
1472 return buildConvertOp<ReinterpretOp, int##WIDTH##_t, FP_TYPE>(builder); \
1473 }
1474
1475BUILD_REINTERPRET_PARSER(32, float)
1476BUILD_REINTERPRET_PARSER(64, double)
1477
1478#undef BUILD_REINTERPRET_PARSER
1479
1480class WasmBinaryParser {
1481private:
1482 struct SectionRegistry {
1483 using section_location_t = StringRef;
1484
1485 std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
1486 registry;
1487
1488 template <WasmSectionType SecType>
1489 std::conditional_t<sectionShouldBeUnique(SecType),
1490 std::optional<section_location_t>,
1492 getContentForSection() const {
1493 constexpr auto idx = static_cast<size_t>(SecType);
1494 if constexpr (sectionShouldBeUnique(SecType)) {
1495 return registry[idx].empty() ? std::nullopt
1496 : std::make_optional(registry[idx][0]);
1497 } else {
1498 return registry[idx];
1499 }
1500 }
1501
1502 bool hasSection(WasmSectionType secType) const {
1503 return !registry[static_cast<size_t>(secType)].empty();
1504 }
1505
1506 ///
1507 /// @returns success if registration valid, failure in case registration
1508 /// can't be done (if another section of same type already exist and this
1509 /// section type should only be present once)
1510 ///
1511 LogicalResult registerSection(WasmSectionType secType,
1512 section_location_t location, Location loc) {
1513 if (sectionShouldBeUnique(secType) && hasSection(secType))
1514 return emitError(loc,
1515 "trying to add a second instance of unique section");
1516
1517 registry[static_cast<size_t>(secType)].push_back(location);
1518 emitRemark(loc, "Adding section with section ID ")
1519 << static_cast<uint8_t>(secType);
1520 return success();
1521 }
1522
1523 LogicalResult populateFromBody(ParserHead ph) {
1524 while (!ph.end()) {
1525 FileLineColLoc sectionLoc = ph.getLocation();
1526 FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
1527 if (failed(secType))
1528 return failure();
1529
1530 FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
1531 if (failed(secSizeParsed))
1532 return failure();
1533
1534 uint32_t secSize = *secSizeParsed;
1535 FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
1536 if (failed(sectionContent))
1537 return failure();
1538
1539 LogicalResult registration =
1540 registerSection(*secType, *sectionContent, sectionLoc);
1541
1542 if (failed(registration))
1543 return failure();
1544 }
1545 return success();
1546 }
1547 };
1548
1549 auto getLocation(int offset = 0) const {
1550 return FileLineColLoc::get(srcName, 0, offset);
1551 }
1552
1553 template <WasmSectionType>
1554 LogicalResult parseSectionItem(ParserHead &, size_t);
1555
1556 template <WasmSectionType section>
1557 LogicalResult parseSection() {
1558 auto secName = std::string{wasmSectionName<section>};
1559 auto sectionNameAttr =
1560 StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION");
1561 unsigned offset = 0;
1562 auto getLocation = [sectionNameAttr, &offset]() {
1563 return FileLineColLoc::get(sectionNameAttr, 0, offset);
1564 };
1565 auto secContent = registry.getContentForSection<section>();
1566 if (!secContent) {
1567 LDBG() << secName << " section is not present in file.";
1568 return success();
1569 }
1570
1571 auto secSrc = secContent.value();
1572 ParserHead ph{secSrc, sectionNameAttr};
1573 FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
1574 if (failed(nElemsParsed))
1575 return failure();
1576 uint32_t nElems = *nElemsParsed;
1577 LDBG() << "starting to parse " << nElems << " items for section "
1578 << secName;
1579 for (size_t i = 0; i < nElems; ++i) {
1580 if (failed(parseSectionItem<section>(ph, i)))
1581 return failure();
1582 }
1583
1584 if (!ph.end())
1585 return emitError(getLocation(), "unparsed garbage at end of section ")
1586 << secName;
1587 return success();
1588 }
1589
1590 /// Handles the registration of a function import
1591 LogicalResult visitImport(Location loc, StringRef moduleName,
1592 StringRef importName, TypeIdxRecord tid) {
1593 using llvm::Twine;
1594 if (tid.id >= symbols.moduleFuncTypes.size())
1595 return emitError(loc, "invalid type id: ")
1596 << tid.id << ". Only " << symbols.moduleFuncTypes.size()
1597 << " type registrations";
1598 FunctionType type = symbols.moduleFuncTypes[tid.id];
1599 std::string symbol = symbols.getNewFuncSymbolName();
1600 auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
1601 importName, type);
1602 symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
1603 return funcOp.verify();
1604 }
1605
1606 /// Handles the registration of a memory import
1607 LogicalResult visitImport(Location loc, StringRef moduleName,
1608 StringRef importName, LimitType limitType) {
1609 std::string symbol = symbols.getNewMemorySymbolName();
1610 auto memOp = MemImportOp::create(builder, loc, symbol, moduleName,
1611 importName, limitType);
1612 symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
1613 return memOp.verify();
1614 }
1615
1616 /// Handles the registration of a table import
1617 LogicalResult visitImport(Location loc, StringRef moduleName,
1618 StringRef importName, TableType tableType) {
1619 std::string symbol = symbols.getNewTableSymbolName();
1620 auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName,
1621 importName, tableType);
1622 symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
1623 return tableOp.verify();
1624 }
1625
1626 /// Handles the registration of a global variable import
1627 LogicalResult visitImport(Location loc, StringRef moduleName,
1628 StringRef importName, GlobalTypeRecord globalType) {
1629 std::string symbol = symbols.getNewGlobalSymbolName();
1630 auto giOp =
1631 GlobalImportOp::create(builder, loc, symbol, moduleName, importName,
1632 globalType.type, globalType.isMutable);
1633 symbols.globalSymbols.push_back(
1634 {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
1635 return giOp.verify();
1636 }
1637
1638 // Detect occurence of errors
1639 LogicalResult peekDiag(Diagnostic &diag) {
1640 if (diag.getSeverity() == DiagnosticSeverity::Error)
1641 isValid = false;
1642 return failure();
1643 }
1644
1645public:
1646 WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
1647 : builder{ctx}, ctx{ctx} {
1649 [this](Diagnostic &diag) { return peekDiag(diag); });
1651 if (sourceMgr.getNumBuffers() != 1) {
1652 emitError(UnknownLoc::get(ctx), "one source file should be provided");
1653 return;
1654 }
1655 uint32_t sourceBufId = sourceMgr.getMainFileID();
1656 StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
1657 srcName = StringAttr::get(
1658 ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
1659
1660 auto parser = ParserHead{source, srcName};
1661 auto const wasmHeader = StringRef{"\0asm", 4};
1662 FileLineColLoc magicLoc = parser.getLocation();
1663 FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
1664 if (failed(magic) || magic->compare(wasmHeader)) {
1665 emitError(magicLoc, "source file does not contain valid Wasm header");
1666 return;
1667 }
1668 auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
1669 FileLineColLoc versionLoc = parser.getLocation();
1670 FailureOr<StringRef> version =
1671 parser.consumeNBytes(expectedVersionString.size());
1672 if (failed(version))
1673 return;
1674 if (version->compare(expectedVersionString)) {
1675 emitError(versionLoc,
1676 "unsupported Wasm version. only version 1 is supported");
1677 return;
1678 }
1679 LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
1680 if (failed(fillRegistry))
1681 return;
1682
1683 mOp = ModuleOp::create(builder, getLocation());
1684 builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
1685 LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
1686 if (failed(parsingTypes))
1687 return;
1688
1689 LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
1690 if (failed(parsingImports))
1691 return;
1692
1693 firstInternalFuncID = symbols.funcSymbols.size();
1694
1695 LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
1696 if (failed(parsingFunctions))
1697 return;
1698
1699 LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
1700 if (failed(parsingTables))
1701 return;
1702
1703 LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
1704 if (failed(parsingMems))
1705 return;
1706
1707 LogicalResult parsingGlobals = parseSection<WasmSectionType::GLOBAL>();
1708 if (failed(parsingGlobals))
1709 return;
1710
1711 LogicalResult parsingCode = parseSection<WasmSectionType::CODE>();
1712 if (failed(parsingCode))
1713 return;
1714
1715 LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
1716 if (failed(parsingExports))
1717 return;
1718
1719 // Copy over sizes of containers into statistics.
1720 LDBG() << "WASM Imports:"
1721 << "\n"
1722 << " - Num functions: " << symbols.funcSymbols.size() << "\n"
1723 << " - Num globals: " << symbols.globalSymbols.size() << "\n"
1724 << " - Num memories: " << symbols.memSymbols.size() << "\n"
1725 << " - Num tables: " << symbols.tableSymbols.size();
1726 }
1727
1728 ModuleOp getModule() {
1729 if (isValid)
1730 return mOp;
1731 if (mOp)
1732 mOp.erase();
1733 return ModuleOp{};
1734 }
1735
1736private:
1737 mlir::StringAttr srcName;
1738 OpBuilder builder;
1739 WasmModuleSymbolTables symbols;
1740 MLIRContext *ctx;
1741 ModuleOp mOp;
1742 SectionRegistry registry;
1743 size_t firstInternalFuncID{0};
1744 bool isValid{true};
1745};
1746
1747template <>
1748LogicalResult
1749WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
1750 size_t) {
1751 FileLineColLoc importLoc = ph.getLocation();
1752 auto moduleName = ph.parseName();
1753 if (failed(moduleName))
1754 return failure();
1755
1756 auto importName = ph.parseName();
1757 if (failed(importName))
1758 return failure();
1759
1760 FailureOr<ImportDesc> import = ph.parseImportDesc(ctx);
1761 if (failed(import))
1762 return failure();
1763
1764 return std::visit(
1765 [this, importLoc, &moduleName, &importName](auto import) {
1766 return visitImport(importLoc, *moduleName, *importName, import);
1767 },
1768 *import);
1769}
1770
1771template <>
1772LogicalResult
1773WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
1774 size_t) {
1775 FileLineColLoc exportLoc = ph.getLocation();
1776
1777 auto exportName = ph.parseName();
1778 if (failed(exportName))
1779 return failure();
1780
1781 FailureOr<std::byte> opcode = ph.consumeByte();
1782 if (failed(opcode))
1783 return failure();
1784
1785 FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
1786 if (failed(idx))
1787 return failure();
1788
1789 using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
1790 SmallVector<GlobalSymbolRefContainer>,
1791 SmallVector<FunctionSymbolRefContainer>>;
1792
1793 SymbolRefDesc currentSymbolList;
1794 std::string symbolType = "";
1795 switch (*opcode) {
1797 symbolType = "function";
1798 currentSymbolList = symbols.funcSymbols;
1799 break;
1801 symbolType = "table";
1802 currentSymbolList = symbols.tableSymbols;
1803 break;
1805 symbolType = "memory";
1806 currentSymbolList = symbols.memSymbols;
1807 break;
1809 symbolType = "global";
1810 currentSymbolList = symbols.globalSymbols;
1811 break;
1812 default:
1813 return emitError(exportLoc, "invalid value for export type: ")
1814 << std::to_integer<unsigned>(*opcode);
1815 }
1816
1817 auto currentSymbol = std::visit(
1818 [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> {
1819 if (*idx > list.size()) {
1820 emitError(
1821 exportLoc,
1822 llvm::formatv(
1823 "trying to export {0} {1} which is undefined in this scope",
1824 symbolType, *idx));
1825 return failure();
1826 }
1827 return list[*idx].symbol;
1828 },
1829 currentSymbolList);
1830
1831 if (failed(currentSymbol))
1832 return failure();
1833
1834 Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
1835 op->setAttr("exported", UnitAttr::get(op->getContext()));
1836 StringAttr symName = SymbolTable::getSymbolName(op);
1837 return SymbolTable{mOp}.rename(symName, *exportName);
1838}
1839
1840template <>
1841LogicalResult
1842WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
1843 size_t) {
1844 FileLineColLoc opLocation = ph.getLocation();
1845 FailureOr<TableType> tableType = ph.parseTableType(ctx);
1846 if (failed(tableType))
1847 return failure();
1848 LDBG() << " Parsed table description: " << *tableType;
1849 StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
1850 auto tableOp =
1851 TableOp::create(builder, opLocation, symbol.strref(), *tableType);
1852 symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
1853 return success();
1854}
1855
1856template <>
1857LogicalResult
1858WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
1859 size_t) {
1860 FileLineColLoc opLoc = ph.getLocation();
1861 auto typeIdxParsed = ph.parseLiteral<uint32_t>();
1862 if (failed(typeIdxParsed))
1863 return failure();
1864 uint32_t typeIdx = *typeIdxParsed;
1865 if (typeIdx >= symbols.moduleFuncTypes.size())
1866 return emitError(getLocation(), "invalid type index: ") << typeIdx;
1867 std::string symbol = symbols.getNewFuncSymbolName();
1868 auto funcOp =
1869 FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
1870 Block *block = funcOp.addEntryBlock();
1871 OpBuilder::InsertionGuard guard{builder};
1872 builder.setInsertionPointToEnd(block);
1873 ReturnOp::create(builder, opLoc);
1874 symbols.funcSymbols.push_back(
1875 {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())},
1876 symbols.moduleFuncTypes[typeIdx]});
1877 return funcOp.verify();
1878}
1879
1880template <>
1881LogicalResult
1882WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
1883 size_t) {
1884 FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
1885 if (failed(funcType))
1886 return failure();
1887 LDBG() << "Parsed function type " << *funcType;
1888 symbols.moduleFuncTypes.push_back(*funcType);
1889 return success();
1890}
1891
1892template <>
1893LogicalResult
1894WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
1895 size_t) {
1896 FileLineColLoc opLocation = ph.getLocation();
1897 FailureOr<LimitType> memory = ph.parseLimit(ctx);
1898 if (failed(memory))
1899 return failure();
1900
1901 LDBG() << " Registering memory " << *memory;
1902 std::string symbol = symbols.getNewMemorySymbolName();
1903 auto memOp = MemOp::create(builder, opLocation, symbol, *memory);
1904 symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
1905 return success();
1906}
1907
1908template <>
1909LogicalResult
1910WasmBinaryParser::parseSectionItem<WasmSectionType::GLOBAL>(ParserHead &ph,
1911 size_t) {
1912 FileLineColLoc globalLocation = ph.getLocation();
1913 auto globalTypeParsed = ph.parseGlobalType(ctx);
1914 if (failed(globalTypeParsed))
1915 return failure();
1916
1917 GlobalTypeRecord globalType = *globalTypeParsed;
1918 auto symbol = builder.getStringAttr(symbols.getNewGlobalSymbolName());
1919 auto globalOp = wasmssa::GlobalOp::create(
1920 builder, globalLocation, symbol, globalType.type, globalType.isMutable);
1921 symbols.globalSymbols.push_back(
1922 {{FlatSymbolRefAttr::get(globalOp)}, globalOp.getType()});
1923 OpBuilder::InsertionGuard guard{builder};
1924 Block *block = builder.createBlock(&globalOp.getInitializer());
1925 builder.setInsertionPointToStart(block);
1926 parsed_inst_t expr = ph.parseExpression(builder, symbols);
1927 if (failed(expr))
1928 return failure();
1929 if (block->empty())
1930 return emitError(globalLocation, "global with empty initializer");
1931 if (expr->size() != 1 && (*expr)[0].getType() != globalType.type)
1932 return emitError(
1933 globalLocation,
1934 "initializer result type does not match global declaration type");
1935 ReturnOp::create(builder, globalLocation, *expr);
1936 return success();
1937}
1938
1939template <>
1940LogicalResult WasmBinaryParser::parseSectionItem<WasmSectionType::CODE>(
1941 ParserHead &ph, size_t innerFunctionId) {
1942 unsigned long funcId = innerFunctionId + firstInternalFuncID;
1943 FunctionSymbolRefContainer symRef = symbols.funcSymbols[funcId];
1944 auto funcOp =
1945 dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(mOp, symRef.symbol));
1946 assert(funcOp);
1947 if (failed(ph.parseCodeFor(funcOp, symbols)))
1948 return failure();
1949 return success();
1950}
1951} // namespace
1952
1953namespace mlir::wasm {
1955 MLIRContext *context) {
1956 WasmBinaryParser wBN{source, context};
1957 ModuleOp mOp = wBN.getModule();
1958 if (mOp)
1959 return {mOp};
1960
1961 return {nullptr};
1962}
1963} // namespace mlir::wasm
return success()
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define EXPORT
#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX)
#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH)
#define APPLY_WASM_SEC_TRANSFORM
#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE)
#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP)
#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH)
#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)
#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX)
#define mul(a, b)
#define add(a, b)
#define div(a, b)
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
bool empty()
Definition Block.h:148
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
Operation & back()
Definition Block.h:152
BlockArgListType getArguments()
Definition Block.h:87
FloatType getF32Type()
Definition Builders.cpp:43
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
MLIRContext * getContext() const
Definition Builders.h:56
IntegerAttr getUI32IntegerAttr(uint32_t value)
Definition Builders.cpp:212
FloatType getF64Type()
Definition Builders.cpp:45
HandlerID registerHandler(HandlerTy handler)
Register a new handler for diagnostics to the engine.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
An instance of this location represents a tuple of file, line number, and column number.
Definition Location.h:174
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
DiagnosticEngine & getDiagEngine()
Returns the diagnostic engine for this context.
void loadAllAvailableDialects()
Load all dialects available in the registry in this context.
This class helps build Operations.
Definition Builders.h:207
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition Builders.h:385
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Block * getBlock() const
Returns the current block of the builder.
Definition Builders.h:448
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition Builders.h:390
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
iterator end()
Definition Region.h:56
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
Type getType() const
Return the type of this value.
Definition Value.h:105
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
OwningOpRef< ModuleOp > importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext *context)
If source contains a valid Wasm binary file, this function returns a ModuleOp containing the repres...
bool isWasmValueType(::mlir::Type type)
bool isWasmRefType(::mlir::Type type)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
static constexpr std::byte memory
static constexpr std::byte table
static constexpr std::byte global
static constexpr std::byte function
static constexpr std::byte memType
static constexpr std::byte typeID
static constexpr std::byte tableType
static constexpr std::byte globalType
static constexpr std::byte globalGet
static constexpr std::byte elseOpCode
static constexpr std::byte promoteF32ToF64
static constexpr std::byte demoteF64ToF32
static constexpr std::byte constI64
static constexpr std::byte constFP64
static constexpr std::byte localTee
static constexpr std::byte ifOpCode
static constexpr std::byte localGet
static constexpr std::byte branchIf
static constexpr std::byte localSet
static constexpr std::byte constI32
static constexpr std::byte constFP32
static constexpr std::byte externRef
static constexpr std::byte i32
static constexpr std::byte funcType
static constexpr std::byte i64
static constexpr std::byte emptyBlockType
static constexpr std::byte funcRef
static constexpr std::byte v128
static constexpr std::byte f64
static constexpr std::byte f32
static constexpr std::byte endByte