22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/Endian.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/LEB128.h"
27 #include "llvm/Support/LogicalResult.h"
// Debug-output tag picked up by the LDBG() statements in this file.
#define DEBUG_TYPE "wasm-translate"

// Input characters are reinterpreted one at a time as std::byte values, so a
// char must be exactly one 8-bit byte for the binary decoding to be correct.
static_assert(CHAR_BIT == 8,
              "This code expects std::byte to be exactly 8 bits");

// Integer type underlying WasmSectionType; a section ID is read as a single
// byte from the section header.
using section_id_t = uint8_t;
44 enum struct WasmSectionType : section_id_t {
60 constexpr section_id_t highestWasmSectionID{
61 static_cast<section_id_t
>(WasmSectionType::DATACOUNT)};
// X-macro enumerating every section kind this parser knows about. Define
// WASM_SEC_TRANSFORM(section) before expanding APPLY_WASM_SEC_TRANSFORM to
// generate one entry per section kind, then #undef WASM_SEC_TRANSFORM.
#define APPLY_WASM_SEC_TRANSFORM \
  WASM_SEC_TRANSFORM(CUSTOM) \
  WASM_SEC_TRANSFORM(TYPE) \
  WASM_SEC_TRANSFORM(IMPORT) \
  WASM_SEC_TRANSFORM(FUNCTION) \
  WASM_SEC_TRANSFORM(TABLE) \
  WASM_SEC_TRANSFORM(MEMORY) \
  WASM_SEC_TRANSFORM(GLOBAL) \
  WASM_SEC_TRANSFORM(EXPORT) \
  WASM_SEC_TRANSFORM(START) \
  WASM_SEC_TRANSFORM(ELEMENT) \
  WASM_SEC_TRANSFORM(CODE) \
  WASM_SEC_TRANSFORM(DATA) \
  WASM_SEC_TRANSFORM(DATACOUNT)
78 template <WasmSectionType>
79 constexpr
const char *wasmSectionName =
"";
81 #define WASM_SEC_TRANSFORM(section) \
83 [[maybe_unused]] constexpr const char \
84 *wasmSectionName<WasmSectionType::section> = #section;
86 #undef WASM_SEC_TRANSFORM
88 constexpr
bool sectionShouldBeUnique(WasmSectionType secType) {
89 return secType != WasmSectionType::CUSTOM;
/// Compile-time list of byte values carried purely in the type, so that an
/// empty object of it can be passed as a tag argument (e.g. to isValueOneOf).
template <std::byte... Seq>
struct ByteSequence {};

/// A ByteSequence holding exactly one byte value.
template <std::byte Single>
struct UniqueByte : public ByteSequence<Single> {};
99 [[maybe_unused]] constexpr ByteSequence<
104 template <std::byte... allowedFlags>
105 constexpr
bool isValueOneOf(std::byte value,
106 ByteSequence<allowedFlags...> = {}) {
107 return ((value == allowedFlags) | ... |
false);
110 template <std::byte... flags>
111 constexpr
bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
112 return !isValueOneOf<flags...>(value);
115 struct GlobalTypeRecord {
120 struct TypeIdxRecord {
124 struct SymbolRefContainer {
128 struct GlobalSymbolRefContainer : SymbolRefContainer {
132 struct FunctionSymbolRefContainer : SymbolRefContainer {
133 FunctionType functionType;
137 std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
139 using parsed_inst_t = FailureOr<SmallVector<Value>>;
141 struct WasmModuleSymbolTables {
148 std::string getNewSymbolName(StringRef prefix,
size_t id)
const {
149 return (prefix + Twine{
id}).str();
152 std::string getNewFuncSymbolName()
const {
153 size_t id = funcSymbols.size();
154 return getNewSymbolName(
"func_",
id);
157 std::string getNewGlobalSymbolName()
const {
158 size_t id = globalSymbols.size();
159 return getNewSymbolName(
"global_",
id);
162 std::string getNewMemorySymbolName()
const {
163 size_t id = memSymbols.size();
164 return getNewSymbolName(
"mem_",
id);
167 std::string getNewTableSymbolName()
const {
168 size_t id = tableSymbols.size();
169 return getNewSymbolName(
"table_",
id);
182 LabelLevelOpInterface levelOp;
186 bool empty()
const {
return values.empty(); }
188 size_t size()
const {
return values.size(); }
198 FailureOr<SmallVector<Value>> popOperands(
TypeRange operandTypes,
209 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
212 LLVM_DUMP_METHOD
void dump()
const;
221 class ExpressionParser {
224 ExpressionParser(ParserHead &parser, WasmModuleSymbolTables
const &symbols,
226 : parser{parser}, symbols{symbols}, locals{initLocal} {}
229 template <std::
byte opCode>
230 inline parsed_inst_t parseSpecificInstruction(
OpBuilder &builder);
232 template <
typename valueT>
235 std::enable_if_t<std::is_arithmetic_v<valueT>> * =
nullptr);
246 template <
typename opcode,
typename valueType,
unsigned int numOperands>
249 std::enable_if_t<std::is_arithmetic_v<valueType>> * =
nullptr);
262 template <
size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
263 inline parsed_inst_t dispatchToInstParser(std::byte opCode,
265 static_assert(patternBitSize <= 8,
266 "PatternBitSize is outside of range of opcode space! "
267 "(expected at most 8 bits)");
268 if constexpr (patternBitSize < 8) {
269 constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
270 constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
271 constexpr
size_t nextPatternBitSize = patternBitSize + 1;
272 if ((opCode & bitSelect) != std::byte{0})
273 return dispatchToInstParser<nextPatternBitSize,
274 nextHighBitPatternStem | std::byte{1}>(
276 return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
279 return parseSpecificInstruction<highBitPattern>(builder);
283 struct ParseResultWithInfo {
285 std::byte endingByte;
289 template <std::
byte ParseEndByte = WasmBinaryEncoding::endByte>
290 parsed_inst_t
parse(
OpBuilder &builder, UniqueByte<ParseEndByte> = {});
292 template <std::byte... ExpressionParseEnd>
293 FailureOr<ParseResultWithInfo>
295 ByteSequence<ExpressionParseEnd...> parsingEndFilters);
297 FailureOr<SmallVector<Value>> popOperands(
TypeRange operandTypes) {
298 return valueStack.popOperands(operandTypes, ¤tOpLoc.value());
301 LogicalResult pushResults(
ValueRange results) {
302 return valueStack.pushResults(results, ¤tOpLoc.value());
308 template <
typename OpToCreate>
309 parsed_inst_t parseSetOrTee(
OpBuilder &);
312 std::optional<Location> currentOpLoc;
314 WasmModuleSymbolTables
const &symbols;
316 ValueStack valueStack;
321 ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
322 ParserHead(ParserHead &&) =
default;
325 ParserHead(ParserHead
const &other) =
default;
328 auto getLocation()
const {
332 FailureOr<StringRef> consumeNBytes(
size_t nBytes) {
333 LDBG() <<
"Consume " << nBytes <<
" bytes";
334 LDBG() <<
" Bytes remaining: " << size();
335 LDBG() <<
" Current offset: " << offset;
337 return emitError(getLocation(),
"trying to extract ")
338 << nBytes <<
"bytes when only " << size() <<
"are available";
340 StringRef res = head.slice(offset, offset + nBytes);
342 LDBG() <<
" Updated offset (+" << nBytes <<
"): " << offset;
346 FailureOr<std::byte> consumeByte() {
347 FailureOr<StringRef> res = consumeNBytes(1);
350 return std::byte{*res->bytes_begin()};
353 template <
typename T>
354 FailureOr<T> parseLiteral();
356 FailureOr<uint32_t> parseVectorSize();
362 inline FailureOr<uint32_t> parseUI32();
363 inline FailureOr<int64_t> parseI64();
366 FailureOr<StringRef> parseName() {
367 FailureOr<uint32_t> size = parseVectorSize();
371 return consumeNBytes(*size);
374 FailureOr<WasmSectionType> parseWasmSectionType() {
375 FailureOr<std::byte>
id = consumeByte();
378 if (std::to_integer<unsigned>(*
id) > highestWasmSectionID)
379 return emitError(getLocation(),
"invalid section ID: ")
380 <<
static_cast<int>(*id);
381 return static_cast<WasmSectionType
>(*id);
384 FailureOr<LimitType> parseLimit(
MLIRContext *ctx) {
387 FailureOr<std::byte> limitHeader = consumeByte();
391 if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
392 return emitError(limitLocation,
"invalid limit header: ")
393 <<
static_cast<int>(*limitHeader);
394 FailureOr<uint32_t> minParse = parseUI32();
397 std::optional<uint32_t>
max{std::nullopt};
398 if (*limitHeader == WasmLimits::bothLimits) {
399 FailureOr<uint32_t> maxParse = parseUI32();
409 FailureOr<std::byte> typeEncoding = consumeByte();
412 switch (*typeEncoding) {
428 return emitError(typeLoc,
"invalid value type encoding: ")
429 <<
static_cast<int>(*typeEncoding);
433 FailureOr<GlobalTypeRecord> parseGlobalType(
MLIRContext *ctx) {
435 FailureOr<Type> typeParsed = parseValueType(ctx);
439 FailureOr<std::byte> mutSpec = consumeByte();
442 if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
443 return emitError(mutLoc,
"invalid global mutability specifier: ")
444 <<
static_cast<int>(*mutSpec);
445 return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
448 FailureOr<TupleType> parseResultType(
MLIRContext *ctx) {
449 FailureOr<uint32_t> nParamsParsed = parseVectorSize();
450 if (
failed(nParamsParsed))
452 uint32_t nParams = *nParamsParsed;
454 res.reserve(nParams);
455 for (
size_t i = 0; i < nParams; ++i) {
456 FailureOr<Type> parsedType = parseValueType(ctx);
459 res.push_back(*parsedType);
464 FailureOr<FunctionType> parseFunctionType(
MLIRContext *ctx) {
466 FailureOr<std::byte> funcTypeHeader = consumeByte();
467 if (
failed(funcTypeHeader))
470 return emitError(typeLoc,
"invalid function type header byte. Expecting ")
472 <<
" got " << std::to_integer<unsigned>(*funcTypeHeader);
473 FailureOr<TupleType> inputTypes = parseResultType(ctx);
477 FailureOr<TupleType> resTypes = parseResultType(ctx);
484 FailureOr<TypeIdxRecord> parseTypeIndex() {
485 FailureOr<uint32_t> res = parseUI32();
488 return TypeIdxRecord{*res};
491 FailureOr<TableType> parseTableType(
MLIRContext *ctx) {
492 FailureOr<Type> elmTypeParse = parseValueType(ctx);
495 if (!isWasmRefType(*elmTypeParse))
496 return emitError(getLocation(),
"invalid element type for table");
497 FailureOr<LimitType> limitParse = parseLimit(ctx);
503 FailureOr<ImportDesc> parseImportDesc(
MLIRContext *ctx) {
505 FailureOr<std::byte> importType = consumeByte();
506 auto packager = [](
auto parseResult) -> FailureOr<ImportDesc> {
509 return {*parseResult};
513 switch (*importType) {
515 return packager(parseTypeIndex());
517 return packager(parseTableType(ctx));
519 return packager(parseLimit(ctx));
521 return packager(parseGlobalType(ctx));
523 return emitError(importLoc,
"invalid import type descriptor: ")
524 <<
static_cast<int>(*importType);
528 parsed_inst_t parseExpression(
OpBuilder &builder,
529 WasmModuleSymbolTables
const &symbols,
531 auto eParser = ExpressionParser{*
this, symbols, locals};
532 return eParser.parse(builder);
535 LogicalResult parseCodeFor(FuncOp func,
536 WasmModuleSymbolTables
const &symbols) {
541 assert(func.getBody().getBlocks().size() == 1 &&
542 "Function should only have its default created block at this point");
544 "Only the placeholder return op should be present at this point");
545 auto returnOp = cast<ReturnOp>(&block.
back());
548 FailureOr<uint32_t> codeSizeInBytes = parseUI32();
549 if (
failed(codeSizeInBytes))
551 FailureOr<StringRef> codeContent = consumeNBytes(*codeSizeInBytes);
555 locName.str() +
"::" + func.getSymName());
556 auto cParser = ParserHead{*codeContent, name};
557 FailureOr<uint32_t> localVecSize = cParser.parseVectorSize();
560 OpBuilder builder{&func.getBody().front().back()};
564 uint32_t nVarVec = *localVecSize;
565 for (
size_t i = 0; i < nVarVec; ++i) {
567 FailureOr<uint32_t> nSubVar = cParser.parseUI32();
570 FailureOr<Type> varT = cParser.parseValueType(func->getContext());
573 for (
size_t j = 0;
j < *nSubVar; ++
j) {
574 auto local = builder.
create<LocalOp>(varLoc, *varT);
575 locals.push_back(local.getResult());
578 parsed_inst_t res = cParser.parseExpression(builder, symbols, locals);
583 "unparsed garbage remaining at end of code block");
584 builder.
create<ReturnOp>(func->getLoc(), *res);
589 bool end()
const {
return curHead().empty(); }
591 ParserHead
copy()
const {
return *
this; }
594 StringRef curHead()
const {
return head.drop_front(offset); }
596 FailureOr<std::byte> peek()
const {
600 "trying to peek at next byte, but input stream is empty");
601 return static_cast<std::byte
>(curHead().front());
604 size_t size()
const {
return head.size() - offset; }
608 unsigned anchorOffset{0};
613 FailureOr<float> ParserHead::parseLiteral<float>() {
614 FailureOr<StringRef> bytes = consumeNBytes(4);
617 return llvm::support::endian::read<float>(bytes->bytes_begin(),
618 llvm::endianness::little);
622 FailureOr<double> ParserHead::parseLiteral<double>() {
623 FailureOr<StringRef> bytes = consumeNBytes(8);
626 return llvm::support::endian::read<double>(bytes->bytes_begin(),
627 llvm::endianness::little);
631 FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
632 char const *error =
nullptr;
634 unsigned encodingSize{0};
635 StringRef src = curHead();
636 uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
637 src.bytes_end(), &error);
642 return emitError(getLocation()) <<
"literal does not fit on 32 bits";
644 res =
static_cast<uint32_t
>(decoded);
645 offset += encodingSize;
650 FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
651 char const *error =
nullptr;
653 unsigned encodingSize{0};
654 StringRef src = curHead();
655 int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
656 src.bytes_end(), &error);
661 return emitError(getLocation()) <<
"literal does not fit on 32 bits";
663 res =
static_cast<int32_t
>(decoded);
664 offset += encodingSize;
669 FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
670 char const *error =
nullptr;
671 unsigned encodingSize{0};
672 StringRef src = curHead();
673 int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
674 src.bytes_end(), &error);
678 offset += encodingSize;
682 FailureOr<uint32_t> ParserHead::parseVectorSize() {
683 return parseLiteral<uint32_t>();
686 inline FailureOr<uint32_t> ParserHead::parseUI32() {
687 return parseLiteral<uint32_t>();
690 inline FailureOr<int64_t> ParserHead::parseI64() {
691 return parseLiteral<int64_t>();
694 template <std::
byte opCode>
695 inline parsed_inst_t ExpressionParser::parseSpecificInstruction(
OpBuilder &) {
696 return emitError(*currentOpLoc,
"unknown instruction opcode: ")
697 <<
static_cast<int>(opCode);
700 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
701 void ValueStack::dump()
const {
702 llvm::dbgs() <<
"================= Wasm ValueStack =======================\n";
703 llvm::dbgs() <<
"size: " << size() <<
"\n";
704 llvm::dbgs() <<
"<Top>"
709 size_t stackSize = size();
710 for (
size_t idx = 0; idx < stackSize; idx++) {
711 size_t actualIdx = stackSize - 1 - idx;
713 values[actualIdx].dump();
715 llvm::dbgs() <<
"<Bottom>"
717 llvm::dbgs() <<
"=========================================================\n";
722 LDBG() <<
"Popping from ValueStack\n"
723 <<
" Elements(s) to pop: " << operandTypes.size() <<
"\n"
724 <<
" Current stack size: " << values.size();
725 if (operandTypes.size() > values.size())
727 "stack doesn't contain enough values. trying to get ")
728 << operandTypes.size() <<
" operands on a stack containing only "
729 << values.size() <<
" values.";
730 size_t stackIdxOffset = values.size() - operandTypes.size();
732 res.reserve(operandTypes.size());
733 for (
size_t i{0}; i < operandTypes.size(); ++i) {
734 Value operand = values[i + stackIdxOffset];
736 if (stackType != operandTypes[i])
737 return emitError(*opLoc,
"invalid operand type on stack. expecting ")
738 << operandTypes[i] <<
", value on stack is of type " << stackType
740 LDBG() <<
" POP: " << operand;
741 res.push_back(operand);
743 values.resize(values.size() - operandTypes.size());
744 LDBG() <<
" Updated stack size: " << values.size();
749 LDBG() <<
"Pushing to ValueStack\n"
750 <<
" Elements(s) to push: " << results.size() <<
"\n"
751 <<
" Current stack size: " << values.size();
752 for (
Value val : results) {
753 if (!isWasmValueType(val.getType()))
754 return emitError(*opLoc,
"invalid value type on stack: ")
756 LDBG() <<
" PUSH: " << val;
757 values.push_back(val);
760 LDBG() <<
" Updated stack size: " << values.size();
764 template <std::
byte EndParseByte>
766 UniqueByte<EndParseByte> endByte) {
767 auto res =
parse(builder, ByteSequence<EndParseByte>{});
770 return res->opResults;
773 template <std::byte... ExpressionParseEnd>
774 FailureOr<ExpressionParser::ParseResultWithInfo>
776 ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
779 currentOpLoc = parser.getLocation();
780 FailureOr<std::byte> opCode = parser.consumeByte();
783 if (isValueOneOf(*opCode, parsingEndFilters))
784 return {{res, *opCode}};
785 parsed_inst_t resParsed;
786 resParsed = dispatchToInstParser(*opCode, builder);
789 std::swap(res, *resParsed);
790 if (
failed(pushResults(res)))
796 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
798 FailureOr<uint32_t>
id = parser.parseLiteral<uint32_t>();
802 if (*
id >= locals.size())
803 return emitError(instLoc,
"invalid local index. function has ")
804 << locals.size() <<
" accessible locals, received index " << *id;
805 return {{builder.
create<LocalGetOp>(instLoc, locals[*id]).getResult()}};
809 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
811 FailureOr<uint32_t>
id = parser.parseLiteral<uint32_t>();
815 if (*
id >= symbols.globalSymbols.size())
816 return emitError(instLoc,
"invalid global index. function has ")
817 << symbols.globalSymbols.size()
818 <<
" accessible globals, received index " << *id;
819 GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id];
820 auto globalOp = builder.
create<GlobalGetOp>(instLoc, globalVar.globalType,
826 template <
typename OpToCreate>
827 parsed_inst_t ExpressionParser::parseSetOrTee(
OpBuilder &builder) {
828 FailureOr<uint32_t>
id = parser.parseLiteral<uint32_t>();
831 if (*
id >= locals.size())
832 return emitError(*currentOpLoc,
"invalid local index. function has ")
833 << locals.size() <<
" accessible locals, received index " << *id;
834 if (valueStack.empty())
837 "invalid stack access, trying to access a value on an empty stack.");
843 builder.
create<OpToCreate>(*currentOpLoc, locals[*id], poppedOp->front())
848 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
850 return parseSetOrTee<LocalSetOp>(builder);
854 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
856 return parseSetOrTee<LocalTeeOp>(builder);
859 template <
typename T>
873 [[maybe_unused]]
inline Type buildLiteralType<uint32_t>(
OpBuilder &builder) {
878 [[maybe_unused]]
inline Type buildLiteralType<uint64_t>(
OpBuilder &builder) {
892 template <
typename ValT,
893 typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
896 template <
typename ValT>
897 struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
898 using type = IntegerAttr;
901 template <
typename ValT>
902 struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
903 using type = FloatAttr;
906 template <
typename ValT>
907 using attr_holder_t =
typename AttrHolder<ValT>::type;
909 template <
typename ValT,
910 typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
911 attr_holder_t<ValT> buildLiteralAttr(
OpBuilder &builder, ValT val) {
915 template <
typename valueT>
916 parsed_inst_t ExpressionParser::parseConstInst(
917 OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
918 auto parsedConstant = parser.parseLiteral<valueT>();
919 if (
failed(parsedConstant))
922 ConstOp::create(builder, *currentOpLoc,
923 buildLiteralAttr<valueT>(builder, *parsedConstant));
924 return {{constOp.getResult()}};
928 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
930 return parseConstInst<int32_t>(builder);
934 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
936 return parseConstInst<int64_t>(builder);
940 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
942 return parseConstInst<float>(builder);
946 inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
948 return parseConstInst<double>(builder);
951 template <
typename opcode,
typename valueType,
unsigned int numOperands>
952 inline parsed_inst_t ExpressionParser::buildNumericOp(
953 OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> *) {
954 auto ty = buildLiteralType<valueType>(builder);
955 LDBG() <<
"*** buildNumericOp: numOperands = " << numOperands
956 <<
", type = " << ty <<
" ***";
958 tysToPop.resize(numOperands);
959 std::fill(tysToPop.begin(), tysToPop.end(), ty);
960 auto operands = popOperands(tysToPop);
963 auto op = builder.
create<opcode>(*currentOpLoc, *operands).getResult();
964 LDBG() <<
"Built operation: " << op;
969 #define BUILD_NUMERIC_OP(OP_NAME, N_ARGS, PREFIX, SUFFIX, TYPE) \
971 inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
972 WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>(OpBuilder & builder) { \
973 return buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder); \
// Helpers that expand to one BUILD_NUMERIC_OP (a parseSpecificInstruction
// specialization) per value type. Binary ops pop 2 operands, unary ops pop 1.
// I32/I64 use int32_t/int64_t; F32/F64 use float/double.

// Binary op over both integer widths.
#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX)                               \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I32, int32_t)                           \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I64, int64_t)

// Binary op over both floating-point widths.
#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)                                \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F32, float)                             \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F64, double)

// Binary op over all four numeric types, integer variants first.
#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX)                             \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I32, int32_t)                           \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I64, int64_t)                           \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F32, float)                             \
  BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F64, double)

// Unary op over both integer widths.
#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX)                            \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I32, int32_t)                           \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I64, int64_t)

// Unary op over both floating-point widths.
#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX)                             \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F32, float)                             \
  BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F64, double)
1031 #undef BUILD_NUMERIC_BINOP_FP
1032 #undef BUILD_NUMERIC_BINOP_INT
1033 #undef BUILD_NUMERIC_BINOP_INTFP
1034 #undef BUILD_NUMERIC_UNARY_OP_FP
1035 #undef BUILD_NUMERIC_UNARY_OP_INT
1036 #undef BUILD_NUMERIC_OP
1037 #undef BUILD_NUMERIC_CAST_OP
1039 class WasmBinaryParser {
1041 struct SectionRegistry {
1042 using section_location_t = StringRef;
1044 std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
1047 template <WasmSectionType SecType>
1048 std::conditional_t<sectionShouldBeUnique(SecType),
1049 std::optional<section_location_t>,
1051 getContentForSection()
const {
1052 constexpr
auto idx =
static_cast<size_t>(SecType);
1053 if constexpr (sectionShouldBeUnique(SecType)) {
1054 return registry[idx].empty() ? std::nullopt
1055 : std::make_optional(registry[idx][0]);
1057 return registry[idx];
1061 bool hasSection(WasmSectionType secType)
const {
1062 return !registry[
static_cast<size_t>(secType)].empty();
1070 LogicalResult registerSection(WasmSectionType secType,
1071 section_location_t location,
Location loc) {
1072 if (sectionShouldBeUnique(secType) && hasSection(secType))
1074 "trying to add a second instance of unique section");
1076 registry[
static_cast<size_t>(secType)].push_back(location);
1077 emitRemark(loc,
"Adding section with section ID ")
1078 <<
static_cast<uint8_t
>(secType);
1082 LogicalResult populateFromBody(ParserHead ph) {
1085 FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
1089 FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
1090 if (
failed(secSizeParsed))
1093 uint32_t secSize = *secSizeParsed;
1094 FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
1095 if (
failed(sectionContent))
1098 LogicalResult registration =
1099 registerSection(*secType, *sectionContent, sectionLoc);
1101 if (
failed(registration))
1108 auto getLocation(
int offset = 0)
const {
1112 template <WasmSectionType>
1113 LogicalResult parseSectionItem(ParserHead &,
size_t);
1115 template <WasmSectionType section>
1116 LogicalResult parseSection() {
1117 auto secName = std::string{wasmSectionName<section>};
1118 auto sectionNameAttr =
1120 unsigned offset = 0;
1121 auto getLocation = [sectionNameAttr, &offset]() {
1124 auto secContent = registry.getContentForSection<section>();
1126 LDBG() << secName <<
" section is not present in file.";
1130 auto secSrc = secContent.value();
1131 ParserHead ph{secSrc, sectionNameAttr};
1132 FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
1133 if (
failed(nElemsParsed))
1135 uint32_t nElems = *nElemsParsed;
1136 LDBG() <<
"starting to parse " << nElems <<
" items for section "
1138 for (
size_t i = 0; i < nElems; ++i) {
1139 if (
failed(parseSectionItem<section>(ph, i)))
1144 return emitError(getLocation(),
"unparsed garbage at end of section ")
1150 LogicalResult visitImport(
Location loc, StringRef moduleName,
1151 StringRef importName, TypeIdxRecord tid) {
1153 if (tid.id >= symbols.moduleFuncTypes.size())
1154 return emitError(loc,
"invalid type id: ")
1155 << tid.id <<
". Only " << symbols.moduleFuncTypes.size()
1156 <<
" type registration.";
1157 FunctionType type = symbols.moduleFuncTypes[tid.id];
1158 std::string symbol = symbols.getNewFuncSymbolName();
1159 auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
1162 return funcOp.verify();
1166 LogicalResult visitImport(
Location loc, StringRef moduleName,
1167 StringRef importName, LimitType limitType) {
1168 std::string symbol = symbols.getNewMemorySymbolName();
1169 auto memOp = MemImportOp::create(builder, loc, symbol, moduleName,
1170 importName, limitType);
1172 return memOp.verify();
1176 LogicalResult visitImport(
Location loc, StringRef moduleName,
1177 StringRef importName, TableType tableType) {
1178 std::string symbol = symbols.getNewTableSymbolName();
1179 auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName,
1180 importName, tableType);
1182 return tableOp.verify();
1186 LogicalResult visitImport(
Location loc, StringRef moduleName,
1187 StringRef importName, GlobalTypeRecord globalType) {
1188 std::string symbol = symbols.getNewGlobalSymbolName();
1190 GlobalImportOp::create(builder, loc, symbol, moduleName, importName,
1191 globalType.type, globalType.isMutable);
1192 symbols.globalSymbols.push_back(
1194 return giOp.verify();
1205 WasmBinaryParser(llvm::SourceMgr &sourceMgr,
MLIRContext *ctx)
1206 : builder{ctx}, ctx{ctx} {
1210 if (sourceMgr.getNumBuffers() != 1) {
1214 uint32_t sourceBufId = sourceMgr.getMainFileID();
1215 StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
1217 ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
1219 auto parser = ParserHead{source, srcName};
1220 auto const wasmHeader = StringRef{
"\0asm", 4};
1222 FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
1223 if (
failed(magic) || magic->compare(wasmHeader)) {
1224 emitError(magicLoc,
"source file does not contain valid Wasm header.");
1227 auto const expectedVersionString = StringRef{
"\1\0\0\0", 4};
1229 FailureOr<StringRef> version =
1230 parser.consumeNBytes(expectedVersionString.size());
1233 if (version->compare(expectedVersionString)) {
1235 "unsupported Wasm version. only version 1 is supported");
1238 LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
1239 if (
failed(fillRegistry))
1242 mOp = ModuleOp::create(builder, getLocation());
1244 LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
1245 if (
failed(parsingTypes))
1248 LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
1249 if (
failed(parsingImports))
1252 firstInternalFuncID = symbols.funcSymbols.size();
1254 LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
1255 if (
failed(parsingFunctions))
1258 LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
1259 if (
failed(parsingTables))
1262 LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
1266 LogicalResult parsingGlobals = parseSection<WasmSectionType::GLOBAL>();
1267 if (
failed(parsingGlobals))
1270 LogicalResult parsingCode = parseSection<WasmSectionType::CODE>();
1274 LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
1275 if (
failed(parsingExports))
1279 LDBG() <<
"WASM Imports:"
1281 <<
" - Num functions: " << symbols.funcSymbols.size() <<
"\n"
1282 <<
" - Num globals: " << symbols.globalSymbols.size() <<
"\n"
1283 <<
" - Num memories: " << symbols.memSymbols.size() <<
"\n"
1284 <<
" - Num tables: " << symbols.tableSymbols.size();
1287 ModuleOp getModule() {
1296 mlir::StringAttr srcName;
1298 WasmModuleSymbolTables symbols;
1301 SectionRegistry registry;
1302 size_t firstInternalFuncID{0};
1308 WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
1311 auto moduleName = ph.parseName();
1315 auto importName = ph.parseName();
1319 FailureOr<ImportDesc>
import = ph.parseImportDesc(ctx);
1324 [
this, importLoc, &moduleName, &importName](
auto import) {
1325 return visitImport(importLoc, *moduleName, *importName,
import);
1332 WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
1336 auto exportName = ph.parseName();
1340 FailureOr<std::byte> opcode = ph.consumeByte();
1344 FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
1348 using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
1352 SymbolRefDesc currentSymbolList;
1353 std::string symbolType =
"";
1356 symbolType =
"function";
1357 currentSymbolList = symbols.funcSymbols;
1360 symbolType =
"table";
1361 currentSymbolList = symbols.tableSymbols;
1364 symbolType =
"memory";
1365 currentSymbolList = symbols.memSymbols;
1368 symbolType =
"global";
1369 currentSymbolList = symbols.globalSymbols;
1372 return emitError(exportLoc,
"invalid value for export type: ")
1373 << std::to_integer<unsigned>(*opcode);
1377 [&](
const auto &list) -> FailureOr<FlatSymbolRefAttr> {
1378 if (*idx > list.size()) {
1382 "trying to export {0} {1} which is undefined in this scope",
1386 return list[*idx].symbol;
1390 if (
failed(currentSymbol))
1401 WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
1404 FailureOr<TableType> tableType = ph.parseTableType(ctx);
1407 LDBG() <<
" Parsed table description: " << *tableType;
1408 StringAttr symbol = builder.
getStringAttr(symbols.getNewTableSymbolName());
1410 TableOp::create(builder, opLocation, symbol.strref(), *tableType);
1417 WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
1420 auto typeIdxParsed = ph.parseLiteral<uint32_t>();
1421 if (
failed(typeIdxParsed))
1423 uint32_t typeIdx = *typeIdxParsed;
1424 if (typeIdx >= symbols.moduleFuncTypes.size())
1425 return emitError(getLocation(),
"invalid type index: ") << typeIdx;
1426 std::string symbol = symbols.getNewFuncSymbolName();
1428 FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
1429 Block *block = funcOp.addEntryBlock();
1432 ReturnOp::create(builder, opLoc);
1433 symbols.funcSymbols.push_back(
1435 symbols.moduleFuncTypes[typeIdx]});
1436 return funcOp.verify();
1441 WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
1443 FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
1446 LDBG() <<
"Parsed function type " << *funcType;
1447 symbols.moduleFuncTypes.push_back(*funcType);
1453 WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
1456 FailureOr<LimitType> memory = ph.parseLimit(ctx);
1460 LDBG() <<
" Registering memory " << *memory;
1461 std::string symbol = symbols.getNewMemorySymbolName();
1462 auto memOp = MemOp::create(builder, opLocation, symbol, *memory);
1469 WasmBinaryParser::parseSectionItem<WasmSectionType::GLOBAL>(ParserHead &ph,
1472 auto globalTypeParsed = ph.parseGlobalType(ctx);
1473 if (
failed(globalTypeParsed))
1476 GlobalTypeRecord globalType = *globalTypeParsed;
1477 auto symbol = builder.
getStringAttr(symbols.getNewGlobalSymbolName());
1478 auto globalOp = builder.
create<wasmssa::GlobalOp>(
1479 globalLocation, symbol, globalType.type, globalType.isMutable);
1480 symbols.globalSymbols.push_back(
1485 parsed_inst_t expr = ph.parseExpression(builder, symbols);
1489 return emitError(globalLocation,
"global with empty initializer");
1490 if (expr->size() != 1 && (*expr)[0].getType() != globalType.type)
1493 "initializer result type does not match global declaration type");
1494 builder.
create<ReturnOp>(globalLocation, *expr);
1499 LogicalResult WasmBinaryParser::parseSectionItem<WasmSectionType::CODE>(
1500 ParserHead &ph,
size_t innerFunctionId) {
1501 unsigned long funcId = innerFunctionId + firstInternalFuncID;
1502 FunctionSymbolRefContainer symRef = symbols.funcSymbols[funcId];
1506 if (
failed(ph.parseCodeFor(funcOp, symbols)))
1515 WasmBinaryParser wBN{source, context};
1516 ModuleOp mOp = wBN.getModule();
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX)
#define APPLY_WASM_SEC_TRANSFORM
#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX)
#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX)
#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX)
Block represents an ordered list of Operations.
OpListType & getOperations()
BlockArgListType getArguments()
StringAttr getStringAttr(const Twine &bytes)
HandlerID registerHandler(HandlerTy handler)
Register a new handler for diagnostics to the engine.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
An instance of this location represents a tuple of file, line number, and column number.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DiagnosticEngine & getDiagEngine()
Returns the diagnostic engine for this context.
void loadAllAvailableDialects()
Load all dialects available in the registry in this context.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
void erase()
Remove this operation from its parent block and delete it.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static void setSymbolVisibility(Operation *symbol, Visibility vis)
Sets the visibility of the given symbol operation.
@ Public
The symbol is public and may be referenced anywhere internal or external to the visible references in...
LogicalResult rename(StringAttr from, StringAttr to)
Renames the given op or the op referred to by the given name to the given new name and updates the sym...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static StringAttr getSymbolName(Operation *symbol)
Returns the name of the given symbol operation, aborting if no symbol is present.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
OwningOpRef< ModuleOp > importWebAssemblyToModule(llvm::SourceMgr &source, MLIRContext *context)
If source contains a valid Wasm binary file, this function returns a ModuleOp containing the repres...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
InFlightDiagnostic emitRemark(Location loc)
Utility method to emit a remark message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static constexpr std::byte memory
static constexpr std::byte table
static constexpr std::byte global
static constexpr std::byte function
Byte encodings describing the mutability of globals.
static constexpr std::byte memType
static constexpr std::byte typeID
static constexpr std::byte tableType
static constexpr std::byte globalType
static constexpr std::byte globalGet
static constexpr std::byte constI64
static constexpr std::byte constFP64
static constexpr std::byte localTee
static constexpr std::byte localGet
static constexpr std::byte localSet
static constexpr std::byte constI32
static constexpr std::byte constFP32
static constexpr std::byte externRef
static constexpr std::byte i32
static constexpr std::byte funcType
static constexpr std::byte i64
static constexpr std::byte funcRef
static constexpr std::byte v128
static constexpr std::byte f64
static constexpr std::byte f32
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.