20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/Endian.h"
25#include "llvm/Support/MemoryBufferRef.h"
26#include "llvm/Support/SourceMgr.h"
36#define DEBUG_TYPE "mlir-bytecode-reader"
48 return "AttrType (2)";
50 return "AttrTypeOffset (3)";
54 return "Resource (5)";
56 return "ResourceOffset (6)";
58 return "DialectVersions (7)";
60 return "Properties (8)";
62 return (
"Unknown (" + Twine(
static_cast<unsigned>(sectionID)) +
")").str();
82 llvm_unreachable(
"unknown section ID");
93 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
94 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
95 explicit EncodingReader(StringRef contents, Location fileLoc)
96 : EncodingReader({
reinterpret_cast<const uint8_t *
>(contents.data()),
101 bool empty()
const {
return dataIt == buffer.end(); }
104 size_t size()
const {
return buffer.end() - dataIt; }
107 LogicalResult alignTo(
unsigned alignment) {
108 if (!llvm::isPowerOf2_32(alignment))
109 return emitError(
"expected alignment to be a power-of-two");
111 auto isUnaligned = [&](
const uint8_t *ptr) {
112 return ((uintptr_t)ptr & (alignment - 1)) != 0;
119 while (isUnaligned(dataIt)) {
121 if (
failed(parseByte(padding)))
124 return emitError(
"expected alignment byte (0xCB), but got: '0x" +
125 llvm::utohexstr(padding) +
"'");
131 if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
132 return emitError(
"expected data iterator aligned to ", alignment,
133 ", but got pointer: '0x" +
134 llvm::utohexstr((uintptr_t)dataIt) +
"'");
141 template <
typename... Args>
142 InFlightDiagnostic
emitError(Args &&...args)
const {
143 return ::emitError(fileLoc).
append(std::forward<Args>(args)...);
145 InFlightDiagnostic
emitError()
const { return ::emitError(fileLoc); }
148 template <
typename T>
149 LogicalResult parseByte(T &value) {
151 return emitError(
"attempting to parse a byte at the end of the bytecode");
152 value =
static_cast<T
>(*dataIt++);
156 LogicalResult parseBytes(
size_t length, ArrayRef<uint8_t> &
result) {
157 if (length > size()) {
158 return emitError(
"attempting to parse ", length,
" bytes when only ",
161 result = {dataIt, length};
167 LogicalResult parseBytes(
size_t length, uint8_t *
result) {
168 if (length > size()) {
169 return emitError(
"attempting to parse ", length,
" bytes when only ",
172 memcpy(
result, dataIt, length);
179 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
180 uint64_t &alignment) {
182 if (
failed(parseVarInt(alignment)) ||
failed(parseVarInt(dataSize)) ||
183 failed(alignTo(alignment)))
185 return parseBytes(dataSize, data);
195 LogicalResult parseVarInt(uint64_t &
result) {
202 if (LLVM_LIKELY(
result & 1)) {
210 if (LLVM_UNLIKELY(
result == 0)) {
211 llvm::support::ulittle64_t resultLE;
212 if (
failed(parseBytes(
sizeof(resultLE),
213 reinterpret_cast<uint8_t *
>(&resultLE))))
218 return parseMultiByteVarInt(
result);
224 LogicalResult parseSignedVarInt(uint64_t &
result) {
234 LogicalResult parseVarIntWithFlag(uint64_t &
result,
bool &flag) {
243 LogicalResult skipBytes(
size_t length) {
244 if (length > size()) {
245 return emitError(
"attempting to skip ", length,
" bytes when only ",
254 LogicalResult parseNullTerminatedString(StringRef &
result) {
255 const char *startIt = (
const char *)dataIt;
256 const char *nulIt = (
const char *)memchr(startIt, 0, size());
259 "malformed null-terminated string, no null character found");
261 result = StringRef(startIt, nulIt - startIt);
262 dataIt = (
const uint8_t *)nulIt + 1;
/// Callback type used to validate the alignment requested by a section. It is
/// invoked with the alignment value and returns a LogicalResult; when it
/// fails, the section parser reports "failed to align section ID: <id>".
267 using ValidateAlignmentFn =
function_ref<LogicalResult(
unsigned alignment)>;
272 ValidateAlignmentFn alignmentValidator,
273 ArrayRef<uint8_t> §ionData) {
274 uint8_t sectionIDAndHasAlignment;
276 if (
failed(parseByte(sectionIDAndHasAlignment)) ||
277 failed(parseVarInt(length)))
284 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
289 return emitError(
"invalid section ID: ",
unsigned(sectionID));
295 if (
failed(parseVarInt(alignment)))
330 if (
failed(alignmentValidator(alignment)))
331 return emitError(
"failed to align section ID: ",
unsigned(sectionID));
334 if (
failed(alignTo(alignment)))
339 return parseBytes(
static_cast<size_t>(length), sectionData);
342 Location getLoc()
const {
return fileLoc; }
351 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &
result) {
357 uint32_t numBytes = llvm::countr_zero<uint32_t>(
result);
358 assert(numBytes > 0 && numBytes <= 7 &&
359 "unexpected number of trailing zeros in varint encoding");
362 llvm::support::ulittle64_t resultLE(
result);
364 parseBytes(numBytes,
reinterpret_cast<uint8_t *
>(&resultLE) + 1)))
369 result = resultLE >> (numBytes + 1);
374 ArrayRef<uint8_t> buffer;
377 const uint8_t *dataIt;
388template <
typename RangeT,
typename T>
389static LogicalResult
resolveEntry(EncodingReader &reader, RangeT &entries,
390 uint64_t
index, T &entry,
391 StringRef entryStr) {
392 if (
index >= entries.size())
393 return reader.emitError(
"invalid ", entryStr,
" index: ",
index);
396 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
397 entry = entries[
index];
399 entry = &entries[
index];
404template <
typename RangeT,
typename T>
405static LogicalResult
parseEntry(EncodingReader &reader, RangeT &entries,
406 T &entry, StringRef entryStr) {
408 if (failed(reader.parseVarInt(entryIdx)))
410 return resolveEntry(reader, entries, entryIdx, entry, entryStr);
420class StringSectionReader {
/// Initializes the string table from `sectionData`; `fileLoc` is used for
/// diagnostics. The out-of-line definition reads a varint string count, then
/// one varint size per string, filling the table in reverse order with data
/// offsets counted back from the end of the section. It fails if a string
/// size exceeds the remaining data, or if bytes are left over between the
/// size table and the string data.
423 LogicalResult
initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
427 LogicalResult parseString(EncodingReader &reader, StringRef &
result)
const {
434 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &
result,
437 if (
failed(reader.parseVarIntWithFlag(entryIdx, flag)))
439 return parseStringAtIndex(reader, entryIdx,
result);
444 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
445 StringRef &
result)
const {
451 SmallVector<StringRef> strings;
455LogicalResult StringSectionReader::initialize(
Location fileLoc,
457 EncodingReader stringReader(sectionData, fileLoc);
461 if (
failed(stringReader.parseVarInt(numStrings)))
463 strings.resize(numStrings);
467 size_t stringDataEndOffset = sectionData.size();
468 for (StringRef &
string : llvm::reverse(strings)) {
470 if (
failed(stringReader.parseVarInt(stringSize)))
472 if (stringDataEndOffset < stringSize) {
473 return stringReader.emitError(
474 "string size exceeds the available data size");
478 size_t stringOffset = stringDataEndOffset - stringSize;
480 reinterpret_cast<const char *
>(sectionData.data() + stringOffset),
482 stringDataEndOffset = stringOffset;
487 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
488 return stringReader.emitError(
"unexpected trailing data between the "
489 "offsets for strings and their data");
502struct BytecodeDialect {
/// Loads this dialect using `ctx`. Must be called before `getLoadedDialect`,
/// which asserts that `load` has already been invoked. Callers check the
/// result with `failed(...)`.
/// NOTE(review): a successful load can apparently still leave the loaded
/// dialect null -- ResourceSectionReader::initialize checks for that after
/// calling `load` and reports the dialect as unknown; confirm against the
/// definition.
507 LogicalResult
load(
const DialectReader &reader, MLIRContext *ctx);
511 Dialect *getLoadedDialect()
const {
513 "expected `load` to be invoked before `getLoadedDialect`");
520 std::optional<Dialect *> dialect;
525 const BytecodeDialectInterface *
interface =
nullptr;
531 ArrayRef<uint8_t> versionBuffer;
534 std::unique_ptr<DialectVersion> loadedVersion;
538struct BytecodeOperationName {
539 BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
540 std::optional<bool> wasRegistered)
541 : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
545 std::optional<OperationName> opName;
548 BytecodeDialect *dialect;
555 std::optional<bool> wasRegistered;
561 EncodingReader &reader,
563 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
565 std::unique_ptr<BytecodeDialect> *dialect;
566 if (failed(
parseEntry(reader, dialects, dialect,
"dialect")))
569 if (failed(reader.parseVarInt(numEntries)))
572 for (uint64_t i = 0; i < numEntries; ++i)
573 if (failed(entryCallback(dialect->get())))
584class ResourceSectionReader {
589 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
590 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
591 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
592 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
595 LogicalResult parseResourceHandle(EncodingReader &reader,
596 AsmDialectResourceHandle &
result)
const {
602 SmallVector<AsmDialectResourceHandle> dialectResources;
603 llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
606class ParsedResourceEntry :
public AsmParsedResourceEntry {
609 EncodingReader &reader, StringSectionReader &stringReader,
610 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
611 : key(key), kind(kind), reader(reader), stringReader(stringReader),
612 bufferOwnerRef(bufferOwnerRef) {}
613 ~ParsedResourceEntry()
override =
default;
615 StringRef getKey() const final {
return key; }
617 InFlightDiagnostic
emitError() const final {
return reader.emitError(); }
621 FailureOr<bool> parseAsBool() const final {
622 if (kind != AsmResourceEntryKind::Bool)
623 return emitError() <<
"expected a bool resource entry, but found a "
624 <<
toString(kind) <<
" entry instead";
627 if (
failed(reader.parseByte(value)))
631 FailureOr<std::string> parseAsString() const final {
632 if (kind != AsmResourceEntryKind::String)
633 return emitError() <<
"expected a string resource entry, but found a "
634 <<
toString(kind) <<
" entry instead";
637 if (
failed(stringReader.parseString(reader,
string)))
642 FailureOr<AsmResourceBlob>
643 parseAsBlob(BlobAllocatorFn allocator)
const final {
644 if (kind != AsmResourceEntryKind::Blob)
645 return emitError() <<
"expected a blob resource entry, but found a "
646 <<
toString(kind) <<
" entry instead";
648 ArrayRef<uint8_t> data;
650 if (
failed(reader.parseBlobAndAlignment(data, alignment)))
655 if (bufferOwnerRef) {
656 ArrayRef<char> charData(
reinterpret_cast<const char *
>(data.data()),
664 [bufferOwnerRef = bufferOwnerRef](
void *,
size_t,
size_t) {});
669 AsmResourceBlob blob = allocator(data.size(), alignment);
670 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.
getData().data()) &&
672 "blob allocator did not return a properly aligned address");
680 EncodingReader &reader;
681 StringSectionReader &stringReader;
682 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
689 EncodingReader &offsetReader, EncodingReader &resourceReader,
690 StringSectionReader &stringReader, T *handler,
691 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
693 function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
694 uint64_t numResources;
695 if (
failed(offsetReader.parseVarInt(numResources)))
698 for (uint64_t i = 0; i < numResources; ++i) {
701 uint64_t resourceOffset;
702 ArrayRef<uint8_t> data;
703 if (
failed(stringReader.parseString(offsetReader, key)) ||
704 failed(offsetReader.parseVarInt(resourceOffset)) ||
705 failed(offsetReader.parseByte(kind)) ||
706 failed(resourceReader.parseBytes(resourceOffset, data)))
710 if ((processKeyFn &&
failed(processKeyFn(key))))
715 if (allowEmpty && data.empty())
723 EncodingReader entryReader(data, fileLoc);
725 ParsedResourceEntry entry(key, kind, entryReader, stringReader,
727 if (
failed(handler->parseResource(entry)))
729 if (!entryReader.empty()) {
730 return entryReader.emitError(
731 "unexpected trailing bytes in resource entry '", key,
"'");
737LogicalResult ResourceSectionReader::initialize(
738 Location fileLoc,
const ParserConfig &
config,
739 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
740 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
741 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
742 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
743 EncodingReader resourceReader(sectionData, fileLoc);
744 EncodingReader offsetReader(offsetSectionData, fileLoc);
747 uint64_t numExternalResourceGroups;
748 if (
failed(offsetReader.parseVarInt(numExternalResourceGroups)))
753 auto parseGroup = [&](
auto *handler,
bool allowEmpty =
false,
755 auto resolveKey = [&](StringRef key) -> StringRef {
756 auto it = dialectResourceHandleRenamingMap.find(key);
757 if (it == dialectResourceHandleRenamingMap.end())
763 stringReader, handler, bufferOwnerRef, resolveKey,
768 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
770 if (
failed(stringReader.parseString(offsetReader, key)))
775 AsmResourceParser *handler =
config.getResourceParser(key);
777 emitWarning(fileLoc) <<
"ignoring unknown external resources for '" << key
781 if (
failed(parseGroup(handler)))
787 while (!offsetReader.empty()) {
788 std::unique_ptr<BytecodeDialect> *dialect;
790 failed((*dialect)->load(dialectReader, ctx)))
792 Dialect *loadedDialect = (*dialect)->getLoadedDialect();
793 if (!loadedDialect) {
794 return resourceReader.emitError()
795 <<
"dialect '" << (*dialect)->name <<
"' is unknown";
797 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
799 return resourceReader.emitError()
800 <<
"unexpected resources for dialect '" << (*dialect)->name <<
"'";
804 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
805 FailureOr<AsmDialectResourceHandle> handle =
806 handler->declareResource(key);
808 return resourceReader.emitError()
809 <<
"unknown 'resource' key '" << key <<
"' for dialect '"
810 << (*dialect)->name <<
"'";
812 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
813 dialectResources.push_back(*handle);
819 if (
failed(parseGroup(handler,
true, processResourceKeyFn)))
851class AttrTypeReader {
853 template <
typename T>
858 BytecodeDialect *dialect =
nullptr;
861 bool hasCustomEncoding =
false;
863 ArrayRef<uint8_t> data;
865 using AttrEntry = Entry<Attribute>;
866 using TypeEntry = Entry<Type>;
869 AttrTypeReader(
const StringSectionReader &stringReader,
870 const ResourceSectionReader &resourceReader,
871 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
872 uint64_t &bytecodeVersion, Location fileLoc,
873 const ParserConfig &
config)
874 : stringReader(stringReader), resourceReader(resourceReader),
875 dialectsMap(dialectsMap), fileLoc(fileLoc),
876 bytecodeVersion(bytecodeVersion), parserConfig(
config) {}
880 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
881 ArrayRef<uint8_t> sectionData,
882 ArrayRef<uint8_t> offsetSectionData);
884 LogicalResult readAttribute(uint64_t index, Attribute &
result,
885 uint64_t depth = 0) {
886 return readEntry(attributes, index,
result,
"attribute", depth);
889 LogicalResult readType(uint64_t index, Type &
result, uint64_t depth = 0) {
890 return readEntry(types, index,
result,
"type", depth);
895 Attribute resolveAttribute(
size_t index, uint64_t depth = 0) {
896 return resolveEntry(attributes, index,
"Attribute", depth);
898 Type resolveType(
size_t index, uint64_t depth = 0) {
902 Attribute getAttributeOrSentinel(
size_t index) {
903 if (index >= attributes.size())
905 return attributes[index].entry;
907 Type getTypeOrSentinel(
size_t index) {
908 if (index >= types.size())
910 return types[index].entry;
916 if (
failed(reader.parseVarInt(attrIdx)))
918 result = resolveAttribute(attrIdx);
921 LogicalResult parseOptionalAttribute(EncodingReader &reader,
925 if (
failed(reader.parseVarIntWithFlag(attrIdx, flag)))
929 result = resolveAttribute(attrIdx);
935 if (
failed(reader.parseVarInt(typeIdx)))
937 result = resolveType(typeIdx);
941 template <
typename T>
943 Attribute baseResult;
946 if ((
result = dyn_cast<T>(baseResult)))
948 return reader.emitError(
"expected attribute of type: ",
949 llvm::getTypeName<T>(),
", but got: ", baseResult);
953 void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
956 bool isResolving()
const {
return resolving; }
960 template <
typename T>
961 T
resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
962 StringRef entryType, uint64_t depth = 0);
966 template <
typename T>
967 LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
968 T &
result, StringRef entryType, uint64_t depth);
972 template <
typename T>
973 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
974 StringRef entryType, uint64_t index,
979 template <
typename T>
980 LogicalResult parseAsmEntry(T &
result, EncodingReader &reader,
981 StringRef entryType);
985 const StringSectionReader &stringReader;
989 const ResourceSectionReader &resourceReader;
993 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
996 SmallVector<AttrEntry> attributes;
997 SmallVector<TypeEntry> types;
1003 uint64_t &bytecodeVersion;
1006 const ParserConfig &parserConfig;
1010 std::vector<uint64_t> deferredWorklist;
1013 bool resolving =
false;
1016class DialectReader :
public DialectBytecodeReader {
1018 DialectReader(AttrTypeReader &attrTypeReader,
1019 const StringSectionReader &stringReader,
1020 const ResourceSectionReader &resourceReader,
1021 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
1022 EncodingReader &reader, uint64_t &bytecodeVersion,
1024 : attrTypeReader(attrTypeReader), stringReader(stringReader),
1025 resourceReader(resourceReader), dialectsMap(dialectsMap),
1026 reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
1028 InFlightDiagnostic
emitError(
const Twine &msg)
const override {
1029 return reader.emitError(msg);
1032 FailureOr<const DialectVersion *>
1033 getDialectVersion(StringRef dialectName)
const override {
1035 auto dialectEntry = dialectsMap.find(dialectName);
1036 if (dialectEntry == dialectsMap.end())
1041 if (
failed(dialectEntry->getValue()->load(*
this, getLoc().
getContext())) ||
1042 dialectEntry->getValue()->loadedVersion ==
nullptr)
1044 return dialectEntry->getValue()->loadedVersion.get();
1047 MLIRContext *
getContext()
const override {
return getLoc().getContext(); }
1049 uint64_t getBytecodeVersion()
const override {
return bytecodeVersion; }
1051 DialectReader withEncodingReader(EncodingReader &encReader)
const {
1052 return DialectReader(attrTypeReader, stringReader, resourceReader,
1053 dialectsMap, encReader, bytecodeVersion);
1056 Location getLoc()
const {
return reader.getLoc(); }
/// Depth threshold for nested attribute/type reads. Once `depth` exceeds this
/// value, readAttribute/readType first try an already-available entry
/// (get*OrSentinel) and otherwise queue the index via addDeferredParsing;
/// below the threshold they recurse into the AttrTypeReader with `depth + 1`.
1064 static constexpr uint64_t maxAttrTypeDepth = 5;
1066 LogicalResult readAttribute(Attribute &
result)
override {
1068 if (
failed(reader.parseVarInt(index)))
1074 if (!attrTypeReader.isResolving()) {
1075 if (Attribute attr = attrTypeReader.resolveAttribute(index)) {
1082 if (depth > maxAttrTypeDepth) {
1083 if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
1087 attrTypeReader.addDeferredParsing(index);
1090 return attrTypeReader.readAttribute(index,
result, depth + 1);
1092 LogicalResult readOptionalAttribute(Attribute &
result)
override {
1093 return attrTypeReader.parseOptionalAttribute(reader,
result);
1095 LogicalResult readType(Type &
result)
override {
1097 if (
failed(reader.parseVarInt(index)))
1103 if (!attrTypeReader.isResolving()) {
1104 if (Type type = attrTypeReader.resolveType(index)) {
1111 if (depth > maxAttrTypeDepth) {
1112 if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
1116 attrTypeReader.addDeferredParsing(index);
1119 return attrTypeReader.readType(index,
result, depth + 1);
1123 AsmDialectResourceHandle handle;
1124 if (
failed(resourceReader.parseResourceHandle(reader, handle)))
1133 LogicalResult readVarInt(uint64_t &
result)
override {
1134 return reader.parseVarInt(
result);
1137 LogicalResult readSignedVarInt(int64_t &
result)
override {
1138 uint64_t unsignedResult;
1139 if (
failed(reader.parseSignedVarInt(unsignedResult)))
1141 result =
static_cast<int64_t
>(unsignedResult);
1145 FailureOr<APInt> readAPIntWithKnownWidth(
unsigned bitWidth)
override {
1147 if (bitWidth <= 8) {
1149 if (
failed(reader.parseByte(value)))
1151 return APInt(bitWidth, value);
1155 if (bitWidth <= 64) {
1157 if (
failed(reader.parseSignedVarInt(value)))
1159 return APInt(bitWidth, value);
1164 uint64_t numActiveWords;
1165 if (
failed(reader.parseVarInt(numActiveWords)))
1167 SmallVector<uint64_t, 4> words(numActiveWords);
1168 for (uint64_t i = 0; i < numActiveWords; ++i)
1169 if (
failed(reader.parseSignedVarInt(words[i])))
1171 return APInt(bitWidth, words);
1175 readAPFloatWithKnownSemantics(
const llvm::fltSemantics &semantics)
override {
1176 FailureOr<APInt> intVal =
1177 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
1180 return APFloat(semantics, *intVal);
1183 LogicalResult readString(StringRef &
result)
override {
1184 return stringReader.parseString(reader,
result);
1187 LogicalResult readBlob(ArrayRef<char> &
result)
override {
1189 ArrayRef<uint8_t> data;
1190 if (
failed(reader.parseVarInt(dataSize)) ||
1191 failed(reader.parseBytes(dataSize, data)))
1193 result = llvm::ArrayRef(
reinterpret_cast<const char *
>(data.data()),
1198 LogicalResult readBool(
bool &
result)
override {
1199 return reader.parseByte(
result);
1203 AttrTypeReader &attrTypeReader;
1204 const StringSectionReader &stringReader;
1205 const ResourceSectionReader &resourceReader;
1206 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1207 EncodingReader &reader;
1208 uint64_t &bytecodeVersion;
1213class PropertiesSectionReader {
1216 LogicalResult
initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1217 if (sectionData.empty())
1219 EncodingReader propReader(sectionData, fileLoc);
1221 if (
failed(propReader.parseVarInt(count)))
1224 if (
failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
1227 EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1228 offsetTable.reserve(count);
1229 for (
auto idx : llvm::seq<int64_t>(0, count)) {
1231 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
1232 ArrayRef<uint8_t> rawProperties;
1234 if (
failed(offsetsReader.parseVarInt(dataSize)) ||
1235 failed(offsetsReader.parseBytes(dataSize, rawProperties)))
1238 if (!offsetsReader.empty())
1239 return offsetsReader.emitError()
1240 <<
"Broken properties section: didn't exhaust the offsets table";
1244 LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1245 OperationName *opName, OperationState &opState)
const {
1246 uint64_t propertiesIdx;
1247 if (
failed(dialectReader.readVarInt(propertiesIdx)))
1249 if (propertiesIdx >= offsetTable.size())
1250 return dialectReader.emitError(
"Properties idx out-of-bound for ")
1252 size_t propertiesOffset = offsetTable[propertiesIdx];
1253 if (propertiesIdx >= propertiesBuffers.size())
1254 return dialectReader.emitError(
"Properties offset out-of-bound for ")
1258 ArrayRef<char> rawProperties;
1262 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
1266 dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
1270 EncodingReader reader(
1271 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1272 DialectReader propReader = dialectReader.withEncodingReader(reader);
1274 auto *iface = opName->
getInterface<BytecodeOpInterface>();
1276 return iface->readProperties(propReader, opState);
1278 return propReader.emitError(
1279 "has properties but missing BytecodeOpInterface for ")
1287 ArrayRef<uint8_t> propertiesBuffers;
1290 SmallVector<int64_t> offsetTable;
1294LogicalResult AttrTypeReader::initialize(
1295 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1296 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1297 EncodingReader offsetReader(offsetSectionData, fileLoc);
1300 uint64_t numAttributes, numTypes;
1301 if (
failed(offsetReader.parseVarInt(numAttributes)) ||
1302 failed(offsetReader.parseVarInt(numTypes)))
1304 attributes.resize(numAttributes);
1305 types.resize(numTypes);
1309 uint64_t currentOffset = 0;
1310 auto parseEntries = [&](
auto &&range) {
1311 size_t currentIndex = 0, endIndex = range.size();
1314 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1315 auto &entry = range[currentIndex++];
1318 if (
failed(offsetReader.parseVarIntWithFlag(entrySize,
1319 entry.hasCustomEncoding)))
1323 if (currentOffset + entrySize > sectionData.size()) {
1324 return offsetReader.emitError(
1325 "Attribute or Type entry offset points past the end of section");
1328 entry.data = sectionData.slice(currentOffset, entrySize);
1329 entry.dialect = dialect;
1330 currentOffset += entrySize;
1333 while (currentIndex != endIndex)
1340 if (
failed(parseEntries(attributes)) ||
failed(parseEntries(types)))
1344 if (!offsetReader.empty()) {
1345 return offsetReader.emitError(
1346 "unexpected trailing data in the Attribute/Type offset section");
1352template <
typename T>
1353T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
1354 uint64_t index, StringRef entryType,
1356 bool oldResolving = resolving;
1358 llvm::scope_exit restoreResolving([&]() { resolving = oldResolving; });
1360 if (index >= entries.size()) {
1361 emitError(fileLoc) <<
"invalid " << entryType <<
" index: " << index;
1367 assert(deferredWorklist.empty());
1369 if (succeeded(readEntry(entries, index,
result, entryType, depth))) {
1370 assert(deferredWorklist.empty());
1373 if (deferredWorklist.empty()) {
1383 std::deque<size_t> worklist;
1384 llvm::DenseSet<size_t> inWorklist;
1387 worklist.push_back(index);
1388 inWorklist.insert(index);
1389 for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1390 if (inWorklist.insert(idx).second)
1391 worklist.push_front(idx);
1394 while (!worklist.empty()) {
1395 size_t currentIndex = worklist.front();
1396 worklist.pop_front();
1399 deferredWorklist.clear();
1402 if (succeeded(readEntry(entries, currentIndex,
result, entryType, depth))) {
1403 inWorklist.erase(currentIndex);
1407 if (deferredWorklist.empty()) {
1413 worklist.push_back(currentIndex);
1416 for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1417 if (inWorklist.insert(idx).second)
1418 worklist.push_front(idx);
1420 deferredWorklist.clear();
1422 return entries[index].entry;
1425template <
typename T>
1426LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
1427 uint64_t index, T &
result,
1428 StringRef entryType, uint64_t depth) {
1429 if (index >= entries.size())
1430 return emitError(fileLoc) <<
"invalid " << entryType <<
" index: " << index;
1433 Entry<T> &entry = entries[index];
1440 EncodingReader reader(entry.data, fileLoc);
1441 LogicalResult parseResult =
1442 entry.hasCustomEncoding
1443 ? parseCustomEntry(entry, reader, entryType, index, depth)
1444 : parseAsmEntry(entry.entry, reader, entryType);
1448 if (!reader.empty())
1449 return reader.emitError(
"unexpected trailing bytes after " + entryType +
1456template <
typename T>
1457LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1458 EncodingReader &reader,
1459 StringRef entryType,
1460 uint64_t index, uint64_t depth) {
1461 DialectReader dialectReader(*
this, stringReader, resourceReader, dialectsMap,
1462 reader, bytecodeVersion, depth);
1466 if constexpr (std::is_same_v<T, Type>) {
1468 for (
const auto &callback :
1471 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1479 reader = EncodingReader(entry.data, reader.getLoc());
1483 for (
const auto &callback :
1486 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1494 reader = EncodingReader(entry.data, reader.getLoc());
1499 if (!entry.dialect->interface) {
1500 return reader.emitError(
"dialect '", entry.dialect->name,
1501 "' does not implement the bytecode interface");
1504 if constexpr (std::is_same_v<T, Type>)
1505 entry.entry = entry.dialect->interface->readType(dialectReader);
1507 entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1509 return success(!!entry.entry);
1512template <
typename T>
1513LogicalResult AttrTypeReader::parseAsmEntry(T &
result, EncodingReader &reader,
1514 StringRef entryType) {
1516 if (
failed(reader.parseNullTerminatedString(asmStr)))
1521 MLIRContext *context = fileLoc->
getContext();
1522 if constexpr (std::is_same_v<T, Type>)
1532 if (numRead != asmStr.size()) {
1533 return reader.emitError(
"trailing characters found after ", entryType,
1534 " assembly format: ", asmStr.drop_front(numRead));
1545 struct RegionReadState;
1546 using LazyLoadableOpsInfo =
1547 std::list<std::pair<Operation *, RegionReadState>>;
1548 using LazyLoadableOpsMap =
1553 llvm::MemoryBufferRef buffer,
1554 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1555 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1556 attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1560 forwardRefOpState(UnknownLoc::
get(config.getContext()),
1561 "builtin.unrealized_conversion_cast",
ValueRange(),
1562 NoneType::
get(config.getContext())),
1563 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1566 LogicalResult read(
Block *block,
1579 this->lazyOpsCallback = lazyOpsCallback;
1580 llvm::scope_exit resetlazyOpsCallback(
1581 [&] { this->lazyOpsCallback =
nullptr; });
1582 auto it = lazyLoadableOpsMap.find(op);
1583 assert(it != lazyLoadableOpsMap.end() &&
1584 "materialize called on non-materializable op");
1590 while (!lazyLoadableOpsMap.empty()) {
1591 if (failed(
materialize(lazyLoadableOpsMap.begin())))
1602 while (!lazyLoadableOps.empty()) {
1603 Operation *op = lazyLoadableOps.begin()->first;
1604 if (shouldMaterialize(op)) {
1605 if (failed(
materialize(lazyLoadableOpsMap.find(op))))
1611 lazyLoadableOps.pop_front();
1612 lazyLoadableOpsMap.erase(op);
1618 LogicalResult
materialize(LazyLoadableOpsMap::iterator it) {
1619 assert(it != lazyLoadableOpsMap.end() &&
1620 "materialize called on non-materializable op");
1621 valueScopes.emplace_back();
1622 std::vector<RegionReadState> regionStack;
1623 regionStack.push_back(std::move(it->getSecond()->second));
1624 lazyLoadableOps.erase(it->getSecond());
1625 lazyLoadableOpsMap.erase(it);
1627 while (!regionStack.empty())
1628 if (failed(
parseRegions(regionStack, regionStack.back())))
1633 LogicalResult checkSectionAlignment(
1644 const bool isGloballyAligned =
1645 ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
1647 if (!isGloballyAligned)
1648 return emitError(
"expected section alignment ")
1649 << alignment <<
" but bytecode buffer 0x"
1650 << Twine::utohexstr((uint64_t)buffer.getBufferStart())
1651 <<
" is not aligned";
/// Parses the bytecode version from `reader`.
/// At the visible call site this runs right after the `"ML\xefR"` prefix has
/// been skipped and before the null-terminated producer string is parsed.
/// NOTE(review): presumably stores the result in `version` -- confirm against
/// the definition.
1660 LogicalResult parseVersion(EncodingReader &reader);
1665 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1670 FailureOr<OperationName> parseOpName(EncodingReader &reader,
1671 std::optional<bool> &wasRegistered);
1677 template <
typename T>
1679 return attrTypeReader.parseAttribute(reader,
result);
1682 return attrTypeReader.parseType(reader,
result);
1689 parseResourceSection(EncodingReader &reader,
1690 std::optional<ArrayRef<uint8_t>> resourceData,
1691 std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1698 struct RegionReadState {
1699 RegionReadState(Operation *op, EncodingReader *reader,
1700 bool isIsolatedFromAbove)
1701 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1702 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1703 bool isIsolatedFromAbove)
1704 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1705 isIsolatedFromAbove(isIsolatedFromAbove) {}
1708 MutableArrayRef<Region>::iterator curRegion, endRegion;
1713 EncodingReader *reader;
1714 std::unique_ptr<EncodingReader> owningReader;
1717 unsigned numValues = 0;
1720 SmallVector<Block *> curBlocks;
1725 uint64_t numOpsRemaining = 0;
1728 bool isIsolatedFromAbove =
false;
1731 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData,
Block *block);
1732 LogicalResult
parseRegions(std::vector<RegionReadState> ®ionStack,
1733 RegionReadState &readState);
1734 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1735 RegionReadState &readState,
1736 bool &isIsolatedFromAbove);
1738 LogicalResult parseRegion(RegionReadState &readState);
1739 LogicalResult parseBlockHeader(EncodingReader &reader,
1740 RegionReadState &readState);
1741 LogicalResult parseBlockArguments(EncodingReader &reader,
Block *block);
1748 Value parseOperand(EncodingReader &reader);
1751 LogicalResult defineValues(EncodingReader &reader,
ValueRange values);
1754 Value createForwardRef();
1762 struct UseListOrderStorage {
1763 UseListOrderStorage(
bool isIndexPairEncoding,
1764 SmallVector<unsigned, 4> &&
indices)
1766 isIndexPairEncoding(isIndexPairEncoding) {};
1769 SmallVector<unsigned, 4>
indices;
1773 bool isIndexPairEncoding;
/// Maps the index of a value within a range (an operation result or a block
/// argument) to the use-list order stored for it in the bytecode.
1781 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
/// Parses the use-list orders for a range of `rangeSize` values. When the
/// range holds more than one value, the number of encoded entries and the
/// index of each entry are read from `reader`; otherwise a single entry for
/// index 0 is read. Each entry consists of a flag selecting the index-pair
/// encoding followed by the list of indices.
1782 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1783 uint64_t rangeSize);
/// Reorders the use-list of `value`. The current uses are ranked by their use
/// IDs in descending order; when a custom order was parsed for the value, its
/// indices (expanded from the index-pair encoding if needed) are checked to
/// form a permutation of the uses before being applied.
1786 LogicalResult sortUseListOrder(Value value);
/// Assigns an ID to every operation nested under `topLevelOp` and then sorts
/// the use-lists of block arguments and operation results. Fails if sorting
/// fails for any value.
1790 LogicalResult processUseLists(Operation *topLevelOp);
1800 void push(RegionReadState &readState) {
1801 nextValueIDs.push_back(values.size());
1802 values.resize(values.size() + readState.numValues);
1807 void pop(RegionReadState &readState) {
1808 values.resize(values.size() - readState.numValues);
1809 nextValueIDs.pop_back();
1813 std::vector<Value> values;
1817 SmallVector<unsigned, 4> nextValueIDs;
1821 const ParserConfig &
config;
1832 LazyLoadableOpsInfo lazyLoadableOps;
1833 LazyLoadableOpsMap lazyLoadableOpsMap;
1834 llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1837 AttrTypeReader attrTypeReader;
1840 uint64_t version = 0;
1846 SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1847 llvm::StringMap<BytecodeDialect *> dialectsMap;
1848 SmallVector<BytecodeOperationName> opNames;
1851 ResourceSectionReader resourceReader;
1855 DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1858 StringSectionReader stringReader;
1861 PropertiesSectionReader propertiesReader;
1864 std::vector<ValueScope> valueScopes;
1871 Block forwardRefOps;
1875 Block openForwardRefOps;
1878 OperationState forwardRefOpState;
1881 llvm::MemoryBufferRef buffer;
1885 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1890 EncodingReader reader(buffer.getBuffer(), fileLoc);
1891 this->lazyOpsCallback = lazyOpsCallback;
1892 llvm::scope_exit resetlazyOpsCallback(
1893 [&] { this->lazyOpsCallback =
nullptr; });
1896 if (failed(reader.skipBytes(StringRef(
"ML\xefR").size())))
1899 if (failed(parseVersion(reader)) ||
1900 failed(reader.parseNullTerminatedString(producer)))
1906 diag.attachNote() <<
"in bytecode version " << version
1907 <<
" produced by: " << producer;
1911 const auto checkSectionAlignment = [&](
unsigned alignment) {
1912 return this->checkSectionAlignment(
1913 alignment, [&](
const auto &msg) {
return reader.emitError(msg); });
1917 std::optional<ArrayRef<uint8_t>>
1919 while (!reader.empty()) {
1924 reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
1928 if (sectionDatas[sectionID]) {
1929 return reader.emitError(
"duplicate top-level section: ",
1932 sectionDatas[sectionID] = sectionData;
1938 return reader.emitError(
"missing data for top-level section: ",
1944 if (failed(stringReader.initialize(
1950 failed(propertiesReader.initialize(
1959 if (failed(parseResourceSection(
1965 if (failed(attrTypeReader.initialize(
1974LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1975 if (failed(reader.parseVarInt(version)))
1981 if (version < minSupportedVersion) {
1982 return reader.emitError(
"bytecode version ", version,
1983 " is older than the current version of ",
1984 currentVersion,
", and upgrade is not supported");
1986 if (version > currentVersion) {
1987 return reader.emitError(
"bytecode version ", version,
1988 " is newer than the current version ",
1993 lazyLoading =
false;
2001LogicalResult BytecodeDialect::load(
const DialectReader &reader,
2007 return reader.emitError(
"dialect '")
2009 <<
"' is unknown. If this is intended, please call "
2010 "allowUnregisteredDialects() on the MLIRContext, or use "
2011 "-allow-unregistered-dialect with the MLIR tool used.";
2013 dialect = loadedDialect;
2018 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
2019 if (!versionBuffer.empty()) {
2021 return reader.emitError(
"dialect '")
2023 <<
"' does not implement the bytecode interface, "
2024 "but found a version entry";
2025 EncodingReader encReader(versionBuffer, reader.getLoc());
2026 DialectReader versionReader = reader.withEncodingReader(encReader);
2027 loadedVersion = interface->readVersion(versionReader);
2035BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
2036 EncodingReader sectionReader(sectionData, fileLoc);
2039 uint64_t numDialects;
2040 if (
failed(sectionReader.parseVarInt(numDialects)))
2042 dialects.resize(numDialects);
2044 const auto checkSectionAlignment = [&](
unsigned alignment) {
2045 return this->checkSectionAlignment(alignment, [&](
const auto &msg) {
2046 return sectionReader.emitError(msg);
2051 for (uint64_t i = 0; i < numDialects; ++i) {
2052 dialects[i] = std::make_unique<BytecodeDialect>();
2056 if (
failed(stringReader.parseString(sectionReader, dialects[i]->name)))
2062 uint64_t dialectNameIdx;
2063 bool versionAvailable;
2064 if (
failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
2067 if (
failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
2068 dialects[i]->name)))
2070 if (versionAvailable) {
2072 if (
failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
2073 dialects[i]->versionBuffer)))
2076 emitError(fileLoc,
"expected dialect version section");
2080 dialectsMap[dialects[i]->name] = dialects[i].get();
2084 auto parseOpName = [&](BytecodeDialect *dialect) {
2086 std::optional<bool> wasRegistered;
2090 if (
failed(stringReader.parseString(sectionReader, opName)))
2093 bool wasRegisteredFlag;
2094 if (
failed(stringReader.parseStringWithFlag(sectionReader, opName,
2095 wasRegisteredFlag)))
2097 wasRegistered = wasRegisteredFlag;
2099 opNames.emplace_back(dialect, opName, wasRegistered);
2106 if (
failed(sectionReader.parseVarInt(numOps)))
2108 opNames.reserve(numOps);
2110 while (!sectionReader.empty())
2116FailureOr<OperationName>
2117BytecodeReader::Impl::parseOpName(EncodingReader &reader,
2118 std::optional<bool> &wasRegistered) {
2119 BytecodeOperationName *opName =
nullptr;
2122 wasRegistered = opName->wasRegistered;
2125 if (!opName->opName) {
2130 if (opName->name.empty()) {
2131 opName->opName.emplace(opName->dialect->name,
getContext());
2134 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2135 dialectsMap, reader, version);
2138 opName->opName.emplace((opName->dialect->name +
"." + opName->name).str(),
2142 return *opName->opName;
2149LogicalResult BytecodeReader::Impl::parseResourceSection(
2150 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
2151 std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
2153 if (resourceData.has_value() != resourceOffsetData.has_value()) {
2154 if (resourceOffsetData)
2155 return emitError(fileLoc,
"unexpected resource offset section when "
2156 "resource section is not present");
2159 "expected resource offset section when resource section is present");
2167 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2168 dialectsMap, reader, version);
2169 return resourceReader.initialize(fileLoc,
config, dialects, stringReader,
2170 *resourceData, *resourceOffsetData,
2171 dialectReader, bufferOwnerRef);
2178FailureOr<BytecodeReader::Impl::UseListMapT>
2179BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
2180 uint64_t numResults) {
2181 BytecodeReader::Impl::UseListMapT map;
2182 uint64_t numValuesToRead = 1;
2183 if (numResults > 1 &&
failed(reader.parseVarInt(numValuesToRead)))
2186 for (
size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
2187 uint64_t resultIdx = 0;
2188 if (numResults > 1 &&
failed(reader.parseVarInt(resultIdx)))
2192 bool indexPairEncoding;
2193 if (
failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
2196 SmallVector<unsigned, 4> useListOrders;
2197 for (
size_t idx = 0; idx < numValues; idx++) {
2199 if (
failed(reader.parseVarInt(index)))
2201 useListOrders.push_back(index);
2205 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
2206 std::move(useListOrders)));
2217LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
2222 bool hasIncomingOrder =
2227 bool alreadySorted =
true;
2231 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
2232 for (
auto item : llvm::drop_begin(llvm::enumerate(value.
getUses()))) {
2234 item.value(), operationIDs.at(item.value().getOwner()));
2235 alreadySorted &= prevID > currentID;
2236 currentOrder.push_back({item.index(), currentID});
2242 if (alreadySorted && !hasIncomingOrder)
2249 currentOrder.begin(), currentOrder.end(),
2250 [](
auto elem1,
auto elem2) { return elem1.second > elem2.second; });
2252 if (!hasIncomingOrder) {
2256 SmallVector<unsigned> shuffle(llvm::make_first_range(currentOrder));
2262 UseListOrderStorage customOrder =
2264 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
2270 if (customOrder.isIndexPairEncoding) {
2272 if (shuffle.size() & 1)
2275 SmallVector<unsigned, 4> newShuffle(numUses);
2277 std::iota(newShuffle.begin(), newShuffle.end(), idx);
2278 for (idx = 0; idx < shuffle.size(); idx += 2)
2279 newShuffle[shuffle[idx]] = shuffle[idx + 1];
2281 shuffle = std::move(newShuffle);
2288 uint64_t accumulator = 0;
2289 for (
const auto &elem : shuffle) {
2290 if (!set.insert(elem).second)
2292 accumulator += elem;
2294 if (numUses != shuffle.size() ||
2295 accumulator != (((numUses - 1) * numUses) >> 1))
2300 shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2301 currentOrder, [&](
auto item) {
return shuffle[item.first]; }));
2306LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2310 unsigned operationID = 0;
2312 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
2314 auto blockWalk = topLevelOp->
walk([
this](
Block *block) {
2316 if (
failed(sortUseListOrder(arg)))
2321 auto resultWalk = topLevelOp->
walk([
this](Operation *op) {
2328 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2336BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2338 EncodingReader reader(sectionData, fileLoc);
2341 std::vector<RegionReadState> regionStack;
2344 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2345 regionStack.emplace_back(*moduleOp, &reader,
true);
2346 regionStack.back().curBlocks.push_back(moduleOp->getBody());
2347 regionStack.back().curBlock = regionStack.back().curRegion->begin();
2348 if (
failed(parseBlockHeader(reader, regionStack.back())))
2350 valueScopes.emplace_back();
2351 valueScopes.back().push(regionStack.back());
2354 while (!regionStack.empty())
2357 if (!forwardRefOps.empty()) {
2358 return reader.emitError(
2359 "not all forward unresolved forward operand references");
2363 if (
failed(processUseLists(*moduleOp)))
2364 return reader.emitError(
2365 "parsed use-list orders were invalid and could not be applied");
2368 for (
const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2371 if (!byteCodeDialect->loadedVersion)
2373 if (byteCodeDialect->interface &&
2374 failed(byteCodeDialect->interface->upgradeFromVersion(
2375 *moduleOp, *byteCodeDialect->loadedVersion)))
2384 auto &parsedOps = moduleOp->getBody()->getOperations();
2386 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
2391BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
2392 RegionReadState &readState) {
2393 const auto checkSectionAlignment = [&](
unsigned alignment) {
2394 return this->checkSectionAlignment(
2395 alignment, [&](
const auto &msg) {
return emitError(fileLoc, msg); });
2401 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
2407 if (
failed(parseRegion(readState)))
2411 if (readState.curRegion->empty())
2416 EncodingReader &reader = *readState.reader;
2418 while (readState.numOpsRemaining--) {
2421 bool isIsolatedFromAbove =
false;
2422 FailureOr<Operation *> op =
2423 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
2431 if ((*op)->getNumRegions()) {
2432 RegionReadState childState(*op, &reader, isIsolatedFromAbove);
2437 ArrayRef<uint8_t> sectionData;
2438 if (
failed(reader.parseSection(sectionID, checkSectionAlignment,
2442 return emitError(fileLoc,
"expected IR section for region");
2443 childState.owningReader =
2444 std::make_unique<EncodingReader>(sectionData, fileLoc);
2445 childState.reader = childState.owningReader.get();
2449 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
2450 lazyLoadableOps.emplace_back(*op, std::move(childState));
2451 lazyLoadableOpsMap.try_emplace(*op,
2452 std::prev(lazyLoadableOps.end()));
2456 regionStack.push_back(std::move(childState));
2459 if (isIsolatedFromAbove)
2460 valueScopes.emplace_back();
2466 if (++readState.curBlock == readState.curRegion->end())
2468 if (
failed(parseBlockHeader(reader, readState)))
2473 readState.curBlock = {};
2474 valueScopes.back().pop(readState);
2479 if (readState.isIsolatedFromAbove) {
2480 assert(!valueScopes.empty() &&
"Expect a valueScope after reading region");
2481 valueScopes.pop_back();
2483 assert(!regionStack.empty() &&
"Expect a regionStack after reading region");
2484 regionStack.pop_back();
2488FailureOr<Operation *>
2489BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2490 RegionReadState &readState,
2491 bool &isIsolatedFromAbove) {
2493 std::optional<bool> wasRegistered;
2494 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2501 if (
failed(reader.parseByte(opMask)))
2511 OperationState opState(opLoc, *opName);
2515 DictionaryAttr dictAttr;
2526 "Unexpected missing `wasRegistered` opname flag at "
2527 "bytecode version ")
2528 << version <<
" with properties.";
2532 if (wasRegistered) {
2533 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2534 dialectsMap, reader, version);
2536 propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
2548 uint64_t numResults;
2549 if (
failed(reader.parseVarInt(numResults)))
2551 opState.
types.resize(numResults);
2552 for (
int i = 0, e = numResults; i < e; ++i)
2559 uint64_t numOperands;
2560 if (
failed(reader.parseVarInt(numOperands)))
2562 opState.
operands.resize(numOperands);
2563 for (
int i = 0, e = numOperands; i < e; ++i)
2564 if (!(opState.
operands[i] = parseOperand(reader)))
2571 if (
failed(reader.parseVarInt(numSuccs)))
2574 for (
int i = 0, e = numSuccs; i < e; ++i) {
2583 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2586 size_t numResults = opState.
types.size();
2587 auto parseResult = parseUseListOrderForRange(reader, numResults);
2590 resultIdxToUseListMap = std::move(*parseResult);
2595 uint64_t numRegions;
2596 if (
failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
2599 opState.
regions.reserve(numRegions);
2600 for (
int i = 0, e = numRegions; i < e; ++i)
2601 opState.
regions.push_back(std::make_unique<Region>());
2606 readState.curBlock->push_back(op);
2617 if (resultIdxToUseListMap.has_value()) {
2619 if (resultIdxToUseListMap->contains(idx)) {
2621 resultIdxToUseListMap->at(idx));
2628LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2629 EncodingReader &reader = *readState.reader;
2633 if (
failed(reader.parseVarInt(numBlocks)))
2642 if (
failed(reader.parseVarInt(numValues)))
2644 readState.numValues = numValues;
2648 readState.curBlocks.clear();
2649 readState.curBlocks.reserve(numBlocks);
2650 for (uint64_t i = 0; i < numBlocks; ++i) {
2651 readState.curBlocks.push_back(
new Block());
2652 readState.curRegion->push_back(readState.curBlocks.back());
2656 valueScopes.back().push(readState);
2659 readState.curBlock = readState.curRegion->begin();
2660 return parseBlockHeader(reader, readState);
2664BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2665 RegionReadState &readState) {
2667 if (
failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2671 if (hasArgs &&
failed(parseBlockArguments(reader, &*readState.curBlock)))
2678 uint8_t hasUseListOrders = 0;
2679 if (hasArgs &&
failed(reader.parseByte(hasUseListOrders)))
2682 if (!hasUseListOrders)
2685 Block &blk = *readState.curBlock;
2686 auto argIdxToUseListMap =
2688 if (
failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
2692 if (argIdxToUseListMap->contains(idx))
2694 argIdxToUseListMap->at(idx));
2700LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2704 if (
failed(reader.parseVarInt(numArgs)))
2707 SmallVector<Type> argTypes;
2708 SmallVector<Location> argLocs;
2709 argTypes.reserve(numArgs);
2710 argLocs.reserve(numArgs);
2712 Location unknownLoc = UnknownLoc::get(
config.getContext());
2715 LocationAttr argLoc = unknownLoc;
2720 if (
failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2721 !(argType = attrTypeReader.resolveType(typeIdx)))
2731 argTypes.push_back(argType);
2732 argLocs.push_back(argLoc);
2742Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2743 std::vector<Value> &values = valueScopes.back().values;
2744 Value *value =
nullptr;
2750 *value = createForwardRef();
2754LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2756 ValueScope &valueScope = valueScopes.back();
2757 std::vector<Value> &values = valueScope.values;
2759 unsigned &valueID = valueScope.nextValueIDs.back();
2760 unsigned valueIDEnd = valueID + newValues.size();
2761 if (valueIDEnd > values.size()) {
2762 return reader.emitError(
2763 "value index range was outside of the expected range for "
2764 "the parent region, got [",
2765 valueID,
", ", valueIDEnd,
"), but the maximum index was ",
2770 for (
unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2771 Value newValue = newValues[i];
2774 if (Value oldValue = std::exchange(values[valueID], newValue)) {
2775 Operation *forwardRefOp = oldValue.getDefiningOp();
2780 assert(forwardRefOp && forwardRefOp->
getBlock() == &forwardRefOps &&
2781 "value index was already defined?");
2783 oldValue.replaceAllUsesWith(newValue);
2784 forwardRefOp->
moveBefore(&openForwardRefOps, openForwardRefOps.end());
2790Value BytecodeReader::Impl::createForwardRef() {
2793 if (!openForwardRefOps.empty()) {
2794 Operation *op = &openForwardRefOps.back();
2795 op->
moveBefore(&forwardRefOps, forwardRefOps.end());
2799 return forwardRefOps.back().getResult(0);
2810 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2814 impl = std::make_unique<Impl>(sourceFileLoc,
config, lazyLoading, buffer,
2820 return impl->read(block, lazyOpsCallback);
2824 return impl->getNumOpsToMaterialize();
2828 return impl->isMaterializable(op);
2833 return impl->materialize(op, lazyOpsCallback);
2838 return impl->finalize(shouldMaterialize);
2842 return buffer.getBuffer().starts_with(
"ML\xefR");
2851 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2857 "input buffer is not an MLIR bytecode file");
2861 buffer, bufferOwnerRef);
2862 return reader.
read(block,
nullptr);
2873 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block,
config,
static LogicalResult parseDialectGrouping(EncodingReader &reader, MutableArrayRef< std::unique_ptr< BytecodeDialect > > dialects, function_ref< LogicalResult(BytecodeDialect *)> entryCallback)
Parse a single dialect group encoded in the byte stream.
static LogicalResult readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
Read the bytecode from the provided memory buffer reference.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version)
Returns true if the given top-level section ID is optional.
static LogicalResult parseResourceGroup(Location fileLoc, bool allowEmpty, EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef, function_ref< StringRef(StringRef)> remapKey={}, function_ref< LogicalResult(StringRef)> processKeyFn={})
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, uint64_t index, T &entry, StringRef entryStr)
Resolve an index into the given entry list.
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, T &entry, StringRef entryStr)
Parse and resolve an index into the given entry list.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::string diag(const llvm::Value &value)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
MutableArrayRef< char > getMutableData()
Return a mutable reference to the raw underlying data of this blob.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
bool isMutable() const
Return if the data of this blob is mutable.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
OpListType & getOperations()
BlockArgListType getArguments()
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Type > > > getTypeCallbacks() const
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Attribute > > > getAttributeCallbacks() const
Returns the callbacks available to the parser.
This class is used to read a bytecode buffer and translate it into MLIR.
LogicalResult materializeAll()
Materialize all operations.
LogicalResult read(Block *block, llvm::function_ref< bool(Operation *)> lazyOps)
Read the bytecode defined within buffer into the given block.
bool isMaterializable(Operation *op)
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, llvm::MemoryBufferRef buffer, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize)
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the client decide if the op should be deleted or materialized.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback)
Materialize the provided operation, invoke the lazyOpsCallback on every newly found lazy operation.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback=[](Operation *) { return false;})
Materialize the provided operation.
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize=[](Operation *) { return true;})
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the client decide if the op should be deleted or materialized.
BytecodeReader(llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef={})
Create a bytecode reader for the given buffer.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
bool isMaterializable(Operation *op)
Return true if the provided op is materializable.
LogicalResult readTopLevel(Block *block, llvm::function_ref< bool(Operation *)> lazyOps=[](Operation *) { return false;})
Read the operations defined within the given memory buffer, containing MLIR bytecode, into the provided block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
bool isRegistered() const
Return if this operation is registered.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a configuration for the MLIR assembly parser.
BytecodeReaderConfig & getBytecodeReaderConfig() const
Returns the parsing configurations associated to the bytecode read.
BlockListType::iterator iterator
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
static AsmResourceBlob allocateWithAlign(ArrayRef< char > data, size_t align, AsmResourceBlob::DeleterFn deleter={}, bool dataIsMutable=false)
Create a new unmanaged resource directly referencing the provided data.
This class provides an abstraction over the different types of ranges over Values.
bool use_empty() const
Returns true if this value has no uses.
void shuffleUseList(ArrayRef< unsigned > indices)
Shuffle the use list order according to the provided indices.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
unsigned getNumUses() const
This method computes the number of uses of this Value.
bool hasOneUse() const
Returns true if this value has exactly one use.
use_iterator use_begin() const
static WalkResult advance()
static WalkResult interrupt()
@ kAttrType
This section contains the attributes and types referenced within an IR module.
@ kAttrTypeOffset
This section contains the offsets for the attribute and types within the AttrType section.
@ kIR
This section contains the list of operations serialized into the bytecode, and their nested regions/operations.
@ kResource
This section contains the resources of the bytecode.
@ kResourceOffset
This section contains the offsets of resources within the Resource section.
@ kDialect
This section contains the dialects referenced within an IR module.
@ kString
This section contains strings referenced within the bytecode.
@ kDialectVersions
This section contains the versions of each dialect.
@ kProperties
This section contains the properties for the operations.
@ kNumSections
The total number of section types.
static uint64_t getUseID(OperandT &val, unsigned ownerID)
Get the unique ID of a value use.
@ kUseListOrdering
Use-list ordering started to be encoded in version 3.
@ kAlignmentByte
An arbitrary value used to fill alignment padding.
@ kVersion
The current bytecode version.
@ kLazyLoading
Support for lazy-loading of isolated region was added in version 2.
@ kDialectVersioning
Dialects versioning was added in version 1.
@ kElideUnknownBlockArgLocation
Avoid recording unknown locations on block arguments (compression) started in version 4.
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attribute...
@ kMinSupportedVersion
The minimum supported version of the bytecode.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
StringRef toString(AsmResourceEntryKind kind)
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
bool isBytecode(llvm::MemoryBufferRef buffer)
Returns true if the given buffer starts with the magic bytes that signal MLIR bytecode.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AsmResourceEntryKind
This enum represents the different kinds of resource values.
LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config)
Read the operations defined within the given memory buffer, containing MLIR bytecode, into the provided block.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.