20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/Endian.h"
25#include "llvm/Support/MemoryBufferRef.h"
26#include "llvm/Support/SourceMgr.h"
36#define DEBUG_TYPE "mlir-bytecode-reader"
48 return "AttrType (2)";
50 return "AttrTypeOffset (3)";
54 return "Resource (5)";
56 return "ResourceOffset (6)";
58 return "DialectVersions (7)";
60 return "Properties (8)";
62 return (
"Unknown (" + Twine(
static_cast<unsigned>(sectionID)) +
")").str();
82 llvm_unreachable(
"unknown section ID");
/// Create a reader over the bytes in `contents`. The read cursor (`dataIt`)
/// starts at the first byte of the buffer, and `fileLoc` is kept as the
/// location that `emitError`/`emitWarning` attach their diagnostics to.
93 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
94 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
95 explicit EncodingReader(StringRef contents, Location fileLoc)
96 : EncodingReader({
reinterpret_cast<const uint8_t *
>(contents.data()),
/// Returns true when the read cursor has reached the end of the buffer, i.e.
/// there are no bytes left to consume.
101 bool empty()
const {
return dataIt == buffer.end(); }
/// Returns the number of bytes between the read cursor and the end of the
/// buffer, i.e. how many bytes are still unread.
104 size_t size()
const {
return buffer.end() - dataIt; }
107 LogicalResult alignTo(
unsigned alignment) {
108 if (!llvm::isPowerOf2_32(alignment))
109 return emitError(
"expected alignment to be a power-of-two");
111 auto isUnaligned = [&](
const uint8_t *ptr) {
112 return ((uintptr_t)ptr & (alignment - 1)) != 0;
119 while (isUnaligned(dataIt)) {
121 if (
failed(parseByte(padding)))
124 return emitError(
"expected alignment byte (0xCB), but got: '0x" +
125 llvm::utohexstr(padding) +
"'");
131 if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
132 return emitError(
"expected data iterator aligned to ", alignment,
133 ", but got pointer: '0x" +
134 llvm::utohexstr((uintptr_t)dataIt) +
"'");
141 template <
typename... Args>
142 InFlightDiagnostic
emitError(Args &&...args)
const {
143 return ::emitError(fileLoc).
append(std::forward<Args>(args)...);
/// Start an error diagnostic at this reader's file location with no message
/// attached yet; the caller can stream the message into the returned
/// diagnostic.
145 InFlightDiagnostic
emitError()
const { return ::emitError(fileLoc); }
148 template <
typename... Args>
149 InFlightDiagnostic
emitWarning(Args &&...args)
const {
150 return ::emitWarning(fileLoc).
append(std::forward<Args>(args)...);
/// Start a warning diagnostic at this reader's file location with no message
/// attached yet; the caller can stream the message into the returned
/// diagnostic.
152 InFlightDiagnostic
emitWarning()
const { return ::emitWarning(fileLoc); }
155 template <
typename T>
156 LogicalResult parseByte(T &value) {
158 return emitError(
"attempting to parse a byte at the end of the bytecode");
159 value =
static_cast<T
>(*dataIt++);
163 LogicalResult parseBytes(
size_t length, ArrayRef<uint8_t> &
result) {
164 if (length > size()) {
165 return emitError(
"attempting to parse ", length,
" bytes when only ",
168 result = {dataIt, length};
174 LogicalResult parseBytes(
size_t length, uint8_t *
result) {
175 if (length > size()) {
176 return emitError(
"attempting to parse ", length,
" bytes when only ",
179 memcpy(
result, dataIt, length);
186 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
187 uint64_t &alignment) {
189 if (
failed(parseVarInt(alignment)) ||
failed(parseVarInt(dataSize)) ||
190 failed(alignTo(alignment)))
192 return parseBytes(dataSize, data);
202 LogicalResult parseVarInt(uint64_t &
result) {
209 if (LLVM_LIKELY(
result & 1)) {
217 if (LLVM_UNLIKELY(
result == 0)) {
218 llvm::support::ulittle64_t resultLE;
219 if (
failed(parseBytes(
sizeof(resultLE),
220 reinterpret_cast<uint8_t *
>(&resultLE))))
225 return parseMultiByteVarInt(
result);
231 LogicalResult parseSignedVarInt(uint64_t &
result) {
241 LogicalResult parseVarIntWithFlag(uint64_t &
result,
bool &flag) {
250 LogicalResult skipBytes(
size_t length) {
251 if (length > size()) {
252 return emitError(
"attempting to skip ", length,
" bytes when only ",
261 LogicalResult parseNullTerminatedString(StringRef &
result) {
262 const char *startIt = (
const char *)dataIt;
263 const char *nulIt = (
const char *)memchr(startIt, 0, size());
266 "malformed null-terminated string, no null character found");
268 result = StringRef(startIt, nulIt - startIt);
269 dataIt = (
const uint8_t *)nulIt + 1;
274 using ValidateAlignmentFn =
function_ref<LogicalResult(
unsigned alignment)>;
279 ValidateAlignmentFn alignmentValidator,
// Out-parameter: filled by the trailing `return parseBytes(..., sectionData)`
// of this function.
280 ArrayRef<uint8_t> &sectionData) {
281 uint8_t sectionIDAndHasAlignment;
283 if (
failed(parseByte(sectionIDAndHasAlignment)) ||
284 failed(parseVarInt(length)))
291 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
296 return emitError(
"invalid section ID: ",
unsigned(sectionID));
302 if (
failed(parseVarInt(alignment)))
337 if (
failed(alignmentValidator(alignment)))
338 return emitError(
"failed to align section ID: ",
unsigned(sectionID));
341 if (
failed(alignTo(alignment)))
346 return parseBytes(
static_cast<size_t>(length), sectionData);
/// Returns the file location this reader was constructed with (the same
/// location its diagnostics are emitted at).
349 Location getLoc()
const {
return fileLoc; }
358 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &
result) {
364 uint32_t numBytes = llvm::countr_zero<uint32_t>(
result);
365 assert(numBytes > 0 && numBytes <= 7 &&
366 "unexpected number of trailing zeros in varint encoding");
369 llvm::support::ulittle64_t resultLE(
result);
371 parseBytes(numBytes,
reinterpret_cast<uint8_t *
>(&resultLE) + 1)))
376 result = resultLE >> (numBytes + 1);
381 ArrayRef<uint8_t> buffer;
384 const uint8_t *dataIt;
395template <
typename RangeT,
typename T>
396static LogicalResult
resolveEntry(EncodingReader &reader, RangeT &entries,
397 uint64_t
index, T &entry,
398 StringRef entryStr) {
399 if (
index >= entries.size())
400 return reader.emitError(
"invalid ", entryStr,
" index: ",
index);
403 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
404 entry = entries[
index];
406 entry = &entries[
index];
411template <
typename RangeT,
typename T>
412static LogicalResult
parseEntry(EncodingReader &reader, RangeT &entries,
413 T &entry, StringRef entryStr) {
415 if (failed(reader.parseVarInt(entryIdx)))
417 return resolveEntry(reader, entries, entryIdx, entry, entryStr);
427class StringSectionReader {
430 LogicalResult
initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
434 LogicalResult parseString(EncodingReader &reader, StringRef &
result)
const {
441 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &
result,
444 if (
failed(reader.parseVarIntWithFlag(entryIdx, flag)))
446 return parseStringAtIndex(reader, entryIdx,
result);
451 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
452 StringRef &
result)
const {
458 SmallVector<StringRef> strings;
462LogicalResult StringSectionReader::initialize(
Location fileLoc,
464 EncodingReader stringReader(sectionData, fileLoc);
468 if (
failed(stringReader.parseVarInt(numStrings)))
470 strings.resize(numStrings);
474 size_t stringDataEndOffset = sectionData.size();
475 for (StringRef &
string : llvm::reverse(strings)) {
477 if (
failed(stringReader.parseVarInt(stringSize)))
479 if (stringDataEndOffset < stringSize) {
480 return stringReader.emitError(
481 "string size exceeds the available data size");
485 size_t stringOffset = stringDataEndOffset - stringSize;
487 reinterpret_cast<const char *
>(sectionData.data() + stringOffset),
489 stringDataEndOffset = stringOffset;
494 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
495 return stringReader.emitError(
"unexpected trailing data between the "
496 "offsets for strings and their data");
509struct BytecodeDialect {
514 LogicalResult
load(
const DialectReader &reader, MLIRContext *ctx);
518 Dialect *getLoadedDialect()
const {
520 "expected `load` to be invoked before `getLoadedDialect`");
527 std::optional<Dialect *> dialect;
532 const BytecodeDialectInterface *
interface =
nullptr;
538 ArrayRef<uint8_t> versionBuffer;
541 std::unique_ptr<DialectVersion> loadedVersion;
545struct BytecodeOperationName {
/// Record the dialect entry, the name string and the optional `wasRegistered`
/// flag exactly as given; nothing else is initialized here.
/// NOTE(review): `wasRegistered` presumably mirrors whether the op was
/// registered when the bytecode was produced — confirm against the writer.
546 BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
547 std::optional<bool> wasRegistered)
548 : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
552 std::optional<OperationName> opName;
555 BytecodeDialect *dialect;
562 std::optional<bool> wasRegistered;
568 EncodingReader &reader,
570 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
572 std::unique_ptr<BytecodeDialect> *dialect;
573 if (failed(
parseEntry(reader, dialects, dialect,
"dialect")))
576 if (failed(reader.parseVarInt(numEntries)))
579 for (uint64_t i = 0; i < numEntries; ++i)
580 if (failed(entryCallback(dialect->get())))
591class ResourceSectionReader {
595 initialize(Location fileLoc,
const ParserConfig &config,
596 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
597 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
598 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
599 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
602 LogicalResult parseResourceHandle(EncodingReader &reader,
603 AsmDialectResourceHandle &
result)
const {
609 SmallVector<AsmDialectResourceHandle> dialectResources;
610 llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
613class ParsedResourceEntry :
public AsmParsedResourceEntry {
616 EncodingReader &reader, StringSectionReader &stringReader,
617 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
618 : key(key), kind(kind), reader(reader), stringReader(stringReader),
619 bufferOwnerRef(bufferOwnerRef) {}
620 ~ParsedResourceEntry()
override =
default;
/// Returns the key this entry was constructed with.
622 StringRef getKey() const final {
return key; }
/// Start an error diagnostic by forwarding to the entry's `EncodingReader`,
/// so the error is reported at that reader's location.
624 InFlightDiagnostic
emitError() const final {
return reader.emitError(); }
628 FailureOr<bool> parseAsBool() const final {
629 if (kind != AsmResourceEntryKind::Bool)
630 return emitError() <<
"expected a bool resource entry, but found a "
631 <<
toString(kind) <<
" entry instead";
634 if (
failed(reader.parseByte(value)))
638 FailureOr<std::string> parseAsString() const final {
639 if (kind != AsmResourceEntryKind::String)
640 return emitError() <<
"expected a string resource entry, but found a "
641 <<
toString(kind) <<
" entry instead";
644 if (
failed(stringReader.parseString(reader,
string)))
649 FailureOr<AsmResourceBlob>
650 parseAsBlob(BlobAllocatorFn allocator)
const final {
651 if (kind != AsmResourceEntryKind::Blob)
652 return emitError() <<
"expected a blob resource entry, but found a "
653 <<
toString(kind) <<
" entry instead";
655 ArrayRef<uint8_t> data;
657 if (
failed(reader.parseBlobAndAlignment(data, alignment)))
662 if (bufferOwnerRef) {
663 ArrayRef<char> charData(
reinterpret_cast<const char *
>(data.data()),
671 [bufferOwnerRef = bufferOwnerRef](
void *,
size_t,
size_t) {});
676 AsmResourceBlob blob = allocator(data.size(), alignment);
677 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.
getData().data()) &&
679 "blob allocator did not return a properly aligned address");
687 EncodingReader &reader;
688 StringSectionReader &stringReader;
689 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
696 EncodingReader &offsetReader, EncodingReader &resourceReader,
697 StringSectionReader &stringReader, T *handler,
698 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
700 function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
701 uint64_t numResources;
702 if (
failed(offsetReader.parseVarInt(numResources)))
705 for (uint64_t i = 0; i < numResources; ++i) {
708 uint64_t resourceOffset;
709 ArrayRef<uint8_t> data;
710 if (
failed(stringReader.parseString(offsetReader, key)) ||
711 failed(offsetReader.parseVarInt(resourceOffset)) ||
712 failed(offsetReader.parseByte(kind)) ||
713 failed(resourceReader.parseBytes(resourceOffset, data)))
717 if ((processKeyFn &&
failed(processKeyFn(key))))
722 if (allowEmpty && data.empty())
730 EncodingReader entryReader(data, fileLoc);
732 ParsedResourceEntry entry(key, kind, entryReader, stringReader,
734 if (
failed(handler->parseResource(entry)))
736 if (!entryReader.empty()) {
737 return entryReader.emitError(
738 "unexpected trailing bytes in resource entry '", key,
"'");
744LogicalResult ResourceSectionReader::initialize(
745 Location fileLoc,
const ParserConfig &config,
746 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
747 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
748 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
749 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
750 EncodingReader resourceReader(sectionData, fileLoc);
751 EncodingReader offsetReader(offsetSectionData, fileLoc);
754 uint64_t numExternalResourceGroups;
755 if (
failed(offsetReader.parseVarInt(numExternalResourceGroups)))
760 auto parseGroup = [&](
auto *handler,
bool allowEmpty =
false,
762 auto resolveKey = [&](StringRef key) -> StringRef {
763 auto it = dialectResourceHandleRenamingMap.find(key);
764 if (it == dialectResourceHandleRenamingMap.end())
770 stringReader, handler, bufferOwnerRef, resolveKey,
775 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
777 if (
failed(stringReader.parseString(offsetReader, key)))
784 emitWarning(fileLoc) <<
"ignoring unknown external resources for '" << key
788 if (
failed(parseGroup(handler)))
794 while (!offsetReader.empty()) {
795 std::unique_ptr<BytecodeDialect> *dialect;
797 failed((*dialect)->load(dialectReader, ctx)))
799 Dialect *loadedDialect = (*dialect)->getLoadedDialect();
800 if (!loadedDialect) {
801 return resourceReader.emitError()
802 <<
"dialect '" << (*dialect)->name <<
"' is unknown";
804 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
806 return resourceReader.emitError()
807 <<
"unexpected resources for dialect '" << (*dialect)->name <<
"'";
811 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
812 FailureOr<AsmDialectResourceHandle> handle =
813 handler->declareResource(key);
815 return resourceReader.emitError()
816 <<
"unknown 'resource' key '" << key <<
"' for dialect '"
817 << (*dialect)->name <<
"'";
819 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
820 dialectResources.push_back(*handle);
826 if (
failed(parseGroup(handler,
true, processResourceKeyFn)))
858class AttrTypeReader {
860 template <
typename T>
865 BytecodeDialect *dialect =
nullptr;
868 bool hasCustomEncoding =
false;
870 ArrayRef<uint8_t> data;
872 using AttrEntry = Entry<Attribute>;
873 using TypeEntry = Entry<Type>;
876 AttrTypeReader(
const StringSectionReader &stringReader,
877 const ResourceSectionReader &resourceReader,
878 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
879 uint64_t &bytecodeVersion, Location fileLoc,
880 const ParserConfig &config)
881 : stringReader(stringReader), resourceReader(resourceReader),
882 dialectsMap(dialectsMap), fileLoc(fileLoc),
883 bytecodeVersion(bytecodeVersion), parserConfig(config) {}
887 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
888 ArrayRef<uint8_t> sectionData,
889 ArrayRef<uint8_t> offsetSectionData);
891 LogicalResult readAttribute(uint64_t index, Attribute &
result,
892 uint64_t depth = 0) {
893 return readEntry(attributes, index,
result,
"attribute", depth);
896 LogicalResult readType(uint64_t index, Type &
result, uint64_t depth = 0) {
897 return readEntry(types, index,
result,
"type", depth);
902 Attribute resolveAttribute(
size_t index, uint64_t depth = 0) {
903 return resolveEntry(attributes, index,
"Attribute", depth);
905 Type resolveType(
size_t index, uint64_t depth = 0) {
909 Attribute getAttributeOrSentinel(
size_t index) {
910 if (index >= attributes.size())
912 return attributes[index].entry;
914 Type getTypeOrSentinel(
size_t index) {
915 if (index >= types.size())
917 return types[index].entry;
923 if (
failed(reader.parseVarInt(attrIdx)))
925 result = resolveAttribute(attrIdx);
928 LogicalResult parseOptionalAttribute(EncodingReader &reader,
932 if (
failed(reader.parseVarIntWithFlag(attrIdx, flag)))
936 result = resolveAttribute(attrIdx);
942 if (
failed(reader.parseVarInt(typeIdx)))
944 result = resolveType(typeIdx);
948 template <
typename T>
950 Attribute baseResult;
953 if ((
result = dyn_cast<T>(baseResult)))
955 return reader.emitError(
"expected attribute of type: ",
956 llvm::getTypeName<T>(),
", but got: ", baseResult);
/// Tags whether a deferred-worklist index refers to the attribute table or
/// the type table (see `addDeferredParsing` and `deferredWorklist`).
960 enum class EntryKind { Attribute, Type };
963 void addDeferredParsing(uint64_t index, EntryKind kind) {
964 deferredWorklist.emplace_back(index, kind);
/// Returns the current value of the `resolving` flag. `resolveEntry` saves
/// the flag on entry and restores it on exit.
968 bool isResolving()
const {
return resolving; }
972 template <
typename T>
973 T
resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
974 StringRef entryType, uint64_t depth = 0);
978 template <
typename T>
979 LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
980 T &
result, StringRef entryType, uint64_t depth);
984 template <
typename T>
985 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
986 StringRef entryType, uint64_t index,
991 template <
typename T>
992 LogicalResult parseAsmEntry(T &
result, EncodingReader &reader,
993 StringRef entryType);
997 const StringSectionReader &stringReader;
1001 const ResourceSectionReader &resourceReader;
1005 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1008 SmallVector<AttrEntry> attributes;
1009 SmallVector<TypeEntry> types;
1015 uint64_t &bytecodeVersion;
1018 const ParserConfig &parserConfig;
1024 std::vector<std::pair<uint64_t, EntryKind>> deferredWorklist;
1027 bool resolving =
false;
1030class DialectReader :
public DialectBytecodeReader {
1032 DialectReader(AttrTypeReader &attrTypeReader,
1033 const StringSectionReader &stringReader,
1034 const ResourceSectionReader &resourceReader,
1035 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
1036 EncodingReader &reader, uint64_t &bytecodeVersion,
1038 : attrTypeReader(attrTypeReader), stringReader(stringReader),
1039 resourceReader(resourceReader), dialectsMap(dialectsMap),
1040 reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
1042 InFlightDiagnostic
emitError(
const Twine &msg)
const override {
1043 return reader.emitError(msg);
1046 InFlightDiagnostic
emitWarning(
const Twine &msg)
const override {
1047 return reader.emitWarning(msg);
1050 FailureOr<const DialectVersion *>
1051 getDialectVersion(StringRef dialectName)
const override {
1053 auto dialectEntry = dialectsMap.find(dialectName);
1054 if (dialectEntry == dialectsMap.end())
1059 if (
failed(dialectEntry->getValue()->load(*
this, getLoc().
getContext())) ||
1060 dialectEntry->getValue()->loadedVersion ==
nullptr)
1062 return dialectEntry->getValue()->loadedVersion.get();
/// Returns the MLIRContext of this reader's location (see `getLoc`).
1065 MLIRContext *
getContext()
const override {
return getLoc().getContext(); }
/// Returns the current value of the bytecode version this reader refers to
/// (`bytecodeVersion` is held by reference).
1067 uint64_t getBytecodeVersion()
const override {
return bytecodeVersion; }
1069 DialectReader withEncodingReader(EncodingReader &encReader)
const {
1070 return DialectReader(attrTypeReader, stringReader, resourceReader,
1071 dialectsMap, encReader, bytecodeVersion);
/// Returns the location of the underlying `EncodingReader`.
1074 Location getLoc()
const {
return reader.getLoc(); }
/// Depth threshold used by `readAttribute`/`readType`: once `depth` exceeds
/// it, they consult `getAttributeOrSentinel`/`getTypeOrSentinel` and register
/// the index through `addDeferredParsing`.
/// NOTE(review): presumably this replaces further recursion — the early
/// return is not visible here; confirm.
1082 static constexpr uint64_t maxAttrTypeDepth = 5;
1084 LogicalResult readAttribute(Attribute &
result)
override {
1086 if (
failed(reader.parseVarInt(index)))
1092 if (!attrTypeReader.isResolving()) {
1093 if (Attribute attr = attrTypeReader.resolveAttribute(index)) {
1100 if (depth > maxAttrTypeDepth) {
1101 if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
1105 attrTypeReader.addDeferredParsing(index,
1106 AttrTypeReader::EntryKind::Attribute);
1109 return attrTypeReader.readAttribute(index,
result, depth + 1);
1111 LogicalResult readOptionalAttribute(Attribute &
result)
override {
1112 return attrTypeReader.parseOptionalAttribute(reader,
result);
1114 LogicalResult readType(Type &
result)
override {
1116 if (
failed(reader.parseVarInt(index)))
1122 if (!attrTypeReader.isResolving()) {
1123 if (Type type = attrTypeReader.resolveType(index)) {
1130 if (depth > maxAttrTypeDepth) {
1131 if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
1135 attrTypeReader.addDeferredParsing(index, AttrTypeReader::EntryKind::Type);
1138 return attrTypeReader.readType(index,
result, depth + 1);
1142 AsmDialectResourceHandle handle;
1143 if (
failed(resourceReader.parseResourceHandle(reader, handle)))
1152 LogicalResult readVarInt(uint64_t &
result)
override {
1153 return reader.parseVarInt(
result);
1156 LogicalResult readSignedVarInt(int64_t &
result)
override {
1157 uint64_t unsignedResult;
1158 if (
failed(reader.parseSignedVarInt(unsignedResult)))
1160 result =
static_cast<int64_t
>(unsignedResult);
1164 FailureOr<APInt> readAPIntWithKnownWidth(
unsigned bitWidth)
override {
1166 if (bitWidth <= 8) {
1168 if (
failed(reader.parseByte(value)))
1170 return APInt(bitWidth, value);
1174 if (bitWidth <= 64) {
1176 if (
failed(reader.parseSignedVarInt(value)))
1178 return APInt(bitWidth, value);
1183 uint64_t numActiveWords;
1184 if (
failed(reader.parseVarInt(numActiveWords)))
1186 SmallVector<uint64_t, 4> words(numActiveWords);
1187 for (uint64_t i = 0; i < numActiveWords; ++i)
1188 if (
failed(reader.parseSignedVarInt(words[i])))
1190 return APInt(bitWidth, words);
1194 readAPFloatWithKnownSemantics(
const llvm::fltSemantics &semantics)
override {
1195 FailureOr<APInt> intVal =
1196 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
1199 return APFloat(semantics, *intVal);
1202 LogicalResult readString(StringRef &
result)
override {
1203 return stringReader.parseString(reader,
result);
1206 LogicalResult readBlob(ArrayRef<char> &
result)
override {
1208 ArrayRef<uint8_t> data;
1209 if (
failed(reader.parseVarInt(dataSize)) ||
1210 failed(reader.parseBytes(dataSize, data)))
1212 result = llvm::ArrayRef(
reinterpret_cast<const char *
>(data.data()),
1217 LogicalResult readBool(
bool &
result)
override {
1218 return reader.parseByte(
result);
1222 AttrTypeReader &attrTypeReader;
1223 const StringSectionReader &stringReader;
1224 const ResourceSectionReader &resourceReader;
1225 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1226 EncodingReader &reader;
1227 uint64_t &bytecodeVersion;
1232class PropertiesSectionReader {
1235 LogicalResult
initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1236 if (sectionData.empty())
1238 EncodingReader propReader(sectionData, fileLoc);
1240 if (
failed(propReader.parseVarInt(count)))
1243 if (
failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
1246 EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1247 offsetTable.reserve(count);
1248 for (
auto idx : llvm::seq<int64_t>(0, count)) {
1250 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
1251 ArrayRef<uint8_t> rawProperties;
1253 if (
failed(offsetsReader.parseVarInt(dataSize)) ||
1254 failed(offsetsReader.parseBytes(dataSize, rawProperties)))
1257 if (!offsetsReader.empty())
1258 return offsetsReader.emitError()
1259 <<
"Broken properties section: didn't exhaust the offsets table";
1263 LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1264 OperationName *opName, OperationState &opState)
const {
1265 uint64_t propertiesIdx;
1266 if (
failed(dialectReader.readVarInt(propertiesIdx)))
1268 if (propertiesIdx >= offsetTable.size())
1269 return dialectReader.emitError(
"Properties idx out-of-bound for ")
1271 size_t propertiesOffset = offsetTable[propertiesIdx];
// Bounds-check the byte offset fetched from `offsetTable`, not the table
// index: the index was already validated against `offsetTable.size()` above,
// and it is the offset that is later used to slice `propertiesBuffers`.
1272 if (propertiesOffset >= propertiesBuffers.size())
1273 return dialectReader.emitError(
"Properties offset out-of-bound for ")
1277 ArrayRef<char> rawProperties;
1281 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
1285 dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
1289 EncodingReader reader(
1290 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1291 DialectReader propReader = dialectReader.withEncodingReader(reader);
1293 auto *iface = opName->
getInterface<BytecodeOpInterface>();
1295 return iface->readProperties(propReader, opState);
1297 return propReader.emitError(
1298 "has properties but missing BytecodeOpInterface for ")
1306 ArrayRef<uint8_t> propertiesBuffers;
1309 SmallVector<int64_t> offsetTable;
1313LogicalResult AttrTypeReader::initialize(
1314 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1315 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1316 EncodingReader offsetReader(offsetSectionData, fileLoc);
1319 uint64_t numAttributes, numTypes;
1320 if (
failed(offsetReader.parseVarInt(numAttributes)) ||
1321 failed(offsetReader.parseVarInt(numTypes)))
1323 attributes.resize(numAttributes);
1324 types.resize(numTypes);
1328 uint64_t currentOffset = 0;
1329 auto parseEntries = [&](
auto &&range) {
1330 size_t currentIndex = 0, endIndex = range.size();
1333 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1334 auto &entry = range[currentIndex++];
1337 if (
failed(offsetReader.parseVarIntWithFlag(entrySize,
1338 entry.hasCustomEncoding)))
1342 if (currentOffset + entrySize > sectionData.size()) {
1343 return offsetReader.emitError(
1344 "Attribute or Type entry offset points past the end of section");
1347 entry.data = sectionData.slice(currentOffset, entrySize);
1348 entry.dialect = dialect;
1349 currentOffset += entrySize;
1352 while (currentIndex != endIndex)
1359 if (
failed(parseEntries(attributes)) ||
failed(parseEntries(types)))
1363 if (!offsetReader.empty()) {
1364 return offsetReader.emitError(
1365 "unexpected trailing data in the Attribute/Type offset section");
1371template <
typename T>
1372T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
1373 uint64_t index, StringRef entryType,
1375 bool oldResolving = resolving;
1377 llvm::scope_exit restoreResolving([&]() { resolving = oldResolving; });
1379 if (index >= entries.size()) {
1380 emitError(fileLoc) <<
"invalid " << entryType <<
" index: " << index;
1386 assert(deferredWorklist.empty());
1388 if (succeeded(readEntry(entries, index,
result, entryType, depth))) {
1389 assert(deferredWorklist.empty());
1392 if (deferredWorklist.empty()) {
1402 std::deque<std::pair<uint64_t, EntryKind>> worklist;
1403 llvm::DenseSet<std::pair<uint64_t, EntryKind>> inWorklist;
1405 EntryKind entryKind =
1406 std::is_same_v<T, Type> ? EntryKind::Type : EntryKind::Attribute;
// Reject instantiations with any T other than Type or Attribute. The text is
// passed as the static_assert message so it appears in the compiler
// diagnostic instead of being folded into the condition.
1408 static_assert(std::is_same_v<T, Type> || std::is_same_v<T, Attribute>,
1409 "Only support resolving Attributes and Types");
1411 auto addToWorklistFront = [&](std::pair<uint64_t, EntryKind> entry) {
1412 if (inWorklist.insert(entry).second)
1413 worklist.push_front(entry);
1417 worklist.emplace_back(index, entryKind);
1418 inWorklist.insert({index, entryKind});
1419 for (
auto entry : llvm::reverse(deferredWorklist))
1420 addToWorklistFront(entry);
1422 while (!worklist.empty()) {
1423 auto [currentIndex, entryKind] = worklist.front();
1424 worklist.pop_front();
1427 deferredWorklist.clear();
1429 if (entryKind == EntryKind::Type) {
1431 if (succeeded(readType(currentIndex,
result, depth))) {
1432 inWorklist.erase({currentIndex, entryKind});
1436 assert(entryKind == EntryKind::Attribute &&
"Unexpected entry kind");
1438 if (succeeded(readAttribute(currentIndex,
result, depth))) {
1439 inWorklist.erase({currentIndex, entryKind});
1444 if (deferredWorklist.empty()) {
1450 worklist.emplace_back(currentIndex, entryKind);
1453 for (
auto entry : llvm::reverse(deferredWorklist))
1454 addToWorklistFront(entry);
1456 deferredWorklist.clear();
1458 return entries[index].entry;
1461template <
typename T>
1462LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
1463 uint64_t index, T &
result,
1464 StringRef entryType, uint64_t depth) {
1465 if (index >= entries.size())
1466 return emitError(fileLoc) <<
"invalid " << entryType <<
" index: " << index;
1469 Entry<T> &entry = entries[index];
1476 EncodingReader reader(entry.data, fileLoc);
1477 LogicalResult parseResult =
1478 entry.hasCustomEncoding
1479 ? parseCustomEntry(entry, reader, entryType, index, depth)
1480 : parseAsmEntry(entry.entry, reader, entryType);
1484 if (!reader.empty())
1485 return reader.emitError(
"unexpected trailing bytes after " + entryType +
1492template <
typename T>
1493LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1494 EncodingReader &reader,
1495 StringRef entryType,
1496 uint64_t index, uint64_t depth) {
1497 DialectReader dialectReader(*
this, stringReader, resourceReader, dialectsMap,
1498 reader, bytecodeVersion, depth);
1502 if constexpr (std::is_same_v<T, Type>) {
1504 for (
const auto &callback :
1506 size_t savedWorklistSize = deferredWorklist.size();
1508 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1517 deferredWorklist.resize(savedWorklistSize);
1518 reader = EncodingReader(entry.data, reader.getLoc());
1522 for (
const auto &callback :
1524 size_t savedWorklistSize = deferredWorklist.size();
1526 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1535 deferredWorklist.resize(savedWorklistSize);
1536 reader = EncodingReader(entry.data, reader.getLoc());
1541 if (!entry.dialect->interface) {
1542 return reader.emitError(
"dialect '", entry.dialect->name,
1543 "' does not implement the bytecode interface");
1546 if constexpr (std::is_same_v<T, Type>)
1547 entry.entry = entry.dialect->interface->readType(dialectReader);
1549 entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1551 return success(!!entry.entry);
1554template <
typename T>
1555LogicalResult AttrTypeReader::parseAsmEntry(T &
result, EncodingReader &reader,
1556 StringRef entryType) {
1558 if (
failed(reader.parseNullTerminatedString(asmStr)))
1563 MLIRContext *context = fileLoc->
getContext();
1564 if constexpr (std::is_same_v<T, Type>)
1574 if (numRead != asmStr.size()) {
1575 return reader.emitError(
"trailing characters found after ", entryType,
1576 " assembly format: ", asmStr.drop_front(numRead));
1587 struct RegionReadState;
1588 using LazyLoadableOpsInfo =
1589 std::list<std::pair<Operation *, RegionReadState>>;
1590 using LazyLoadableOpsMap =
1595 llvm::MemoryBufferRef buffer,
1596 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1597 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1598 attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1602 forwardRefOpState(UnknownLoc::
get(config.getContext()),
1603 "builtin.unrealized_conversion_cast",
ValueRange(),
1604 NoneType::
get(config.getContext())),
1605 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1608 LogicalResult read(
Block *block,
1621 this->lazyOpsCallback = lazyOpsCallback;
1622 llvm::scope_exit resetlazyOpsCallback(
1623 [&] { this->lazyOpsCallback =
nullptr; });
1624 auto it = lazyLoadableOpsMap.find(op);
1625 assert(it != lazyLoadableOpsMap.end() &&
1626 "materialize called on non-materializable op");
1632 while (!lazyLoadableOpsMap.empty()) {
1633 if (failed(
materialize(lazyLoadableOpsMap.begin())))
1644 while (!lazyLoadableOps.empty()) {
1645 Operation *op = lazyLoadableOps.begin()->first;
1646 if (shouldMaterialize(op)) {
1647 if (failed(
materialize(lazyLoadableOpsMap.find(op))))
1653 lazyLoadableOps.pop_front();
1654 lazyLoadableOpsMap.erase(op);
1660 LogicalResult
materialize(LazyLoadableOpsMap::iterator it) {
1661 assert(it != lazyLoadableOpsMap.end() &&
1662 "materialize called on non-materializable op");
1663 valueScopes.emplace_back();
1664 std::vector<RegionReadState> regionStack;
1665 regionStack.push_back(std::move(it->getSecond()->second));
1666 lazyLoadableOps.erase(it->getSecond());
1667 lazyLoadableOpsMap.erase(it);
1669 while (!regionStack.empty())
1670 if (failed(
parseRegions(regionStack, regionStack.back())))
1675 LogicalResult checkSectionAlignment(
1686 const bool isGloballyAligned =
1687 ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
1689 if (!isGloballyAligned)
1690 return emitError(
"expected section alignment ")
1691 << alignment <<
" but bytecode buffer 0x"
1692 << Twine::utohexstr((uint64_t)buffer.getBufferStart())
1693 <<
" is not aligned";
1702 LogicalResult parseVersion(EncodingReader &reader);
1707 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1712 FailureOr<OperationName> parseOpName(EncodingReader &reader,
1713 std::optional<bool> &wasRegistered);
1719 template <
typename T>
1721 return attrTypeReader.parseAttribute(reader,
result);
1724 return attrTypeReader.parseType(reader,
result);
1731 parseResourceSection(EncodingReader &reader,
1732 std::optional<ArrayRef<uint8_t>> resourceData,
1733 std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1740 struct RegionReadState {
/// Convenience constructor: delegates to the region-range constructor using
/// all regions of `op`.
1741 RegionReadState(Operation *op, EncodingReader *reader,
1742 bool isIsolatedFromAbove)
1743 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
/// Set up the state to walk `regions` from first to last. The `reader`
/// pointer is stored as given; `owningReader` is not set here and so starts
/// out null. The remaining members keep their in-class defaults.
1744 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1745 bool isIsolatedFromAbove)
1746 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1747 isIsolatedFromAbove(isIsolatedFromAbove) {}
1750 MutableArrayRef<Region>::iterator curRegion, endRegion;
1755 EncodingReader *reader;
1756 std::unique_ptr<EncodingReader> owningReader;
1759 unsigned numValues = 0;
1762 SmallVector<Block *> curBlocks;
1767 uint64_t numOpsRemaining = 0;
1770 bool isIsolatedFromAbove =
false;
1773 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData,
Block *block);
/// Declared here and defined out of line. `materialize` calls this repeatedly
/// with the region stack and its back element until the stack is empty.
1774 LogicalResult
parseRegions(std::vector<RegionReadState> &regionStack,
1775 RegionReadState &readState);
1776 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1777 RegionReadState &readState,
1778 bool &isIsolatedFromAbove);
1780 LogicalResult parseRegion(RegionReadState &readState);
1781 LogicalResult parseBlockHeader(EncodingReader &reader,
1782 RegionReadState &readState);
1783 LogicalResult parseBlockArguments(EncodingReader &reader,
Block *block);
1790 Value parseOperand(EncodingReader &reader);
1793 LogicalResult defineValues(EncodingReader &reader,
ValueRange values);
1796 Value createForwardRef();
/// Holds a use-list ordering parsed from the bytecode for a single value,
/// together with the flag describing how the indices are encoded.
1804 struct UseListOrderStorage {
/// Takes ownership of `indices` (passed as an rvalue) and records the
/// encoding flag.
1805 UseListOrderStorage(
bool isIndexPairEncoding,
1806 SmallVector<unsigned, 4> &&
indices)
1808 isIndexPairEncoding(isIndexPairEncoding) {};
/// The raw indices read from the bytecode.
1811 SmallVector<unsigned, 4>
indices;
/// When true, `indices` is interpreted as consecutive (position, value)
/// pairs rather than as a full permutation.
1815 bool isIndexPairEncoding;
1823 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1824 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1825 uint64_t rangeSize);
1828 LogicalResult sortUseListOrder(Value value);
1832 LogicalResult processUseLists(Operation *topLevelOp);
/// Enter a new region: remember the current size of `values` as the first
/// value ID of the region, then grow `values` by the number of values the
/// region declares.
1842 void push(RegionReadState &readState) {
1843 nextValueIDs.push_back(values.size());
1844 values.resize(values.size() + readState.numValues);
/// Leave a region: shrink `values` by the number of values the region
/// declared and drop its entry from `nextValueIDs`. Mirrors `push`.
1849 void pop(RegionReadState &readState) {
1850 values.resize(values.size() - readState.numValues);
1851 nextValueIDs.pop_back();
1855 std::vector<Value> values;
1859 SmallVector<unsigned, 4> nextValueIDs;
1863 const ParserConfig &config;
1874 LazyLoadableOpsInfo lazyLoadableOps;
1875 LazyLoadableOpsMap lazyLoadableOpsMap;
1876 llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1879 AttrTypeReader attrTypeReader;
1882 uint64_t version = 0;
1888 SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1889 llvm::StringMap<BytecodeDialect *> dialectsMap;
1890 SmallVector<BytecodeOperationName> opNames;
1893 ResourceSectionReader resourceReader;
1897 DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1900 StringSectionReader stringReader;
1903 PropertiesSectionReader propertiesReader;
1906 std::vector<ValueScope> valueScopes;
1913 Block forwardRefOps;
1917 Block openForwardRefOps;
1920 OperationState forwardRefOpState;
1923 llvm::MemoryBufferRef buffer;
1927 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1932 EncodingReader reader(buffer.getBuffer(), fileLoc);
1933 this->lazyOpsCallback = lazyOpsCallback;
1934 llvm::scope_exit resetlazyOpsCallback(
1935 [&] { this->lazyOpsCallback =
nullptr; });
1938 if (failed(reader.skipBytes(StringRef(
"ML\xefR").size())))
1941 if (failed(parseVersion(reader)) ||
1942 failed(reader.parseNullTerminatedString(producer)))
1948 diag.attachNote() <<
"in bytecode version " << version
1949 <<
" produced by: " << producer;
1953 const auto checkSectionAlignment = [&](
unsigned alignment) {
1954 return this->checkSectionAlignment(
1955 alignment, [&](
const auto &msg) {
return reader.emitError(msg); });
1959 std::optional<ArrayRef<uint8_t>>
1961 while (!reader.empty()) {
1966 reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
1970 if (sectionDatas[sectionID]) {
1971 return reader.emitError(
"duplicate top-level section: ",
1974 sectionDatas[sectionID] = sectionData;
1980 return reader.emitError(
"missing data for top-level section: ",
1986 if (failed(stringReader.initialize(
1992 failed(propertiesReader.initialize(
2001 if (failed(parseResourceSection(
2007 if (failed(attrTypeReader.initialize(
/// Parse the bytecode version from `reader` into the `version` member and
/// validate it against the supported range. Emits an error through `reader`
/// when the version is older than `minSupportedVersion` or newer than
/// `currentVersion`.
2016LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
// The version is stored as a varint.
2017 if (failed(reader.parseVarInt(version)))
// Reject files that predate the oldest version this reader understands.
2023 if (version < minSupportedVersion) {
2024 return reader.emitError(
"bytecode version ", version,
2025 " is older than the current version of ",
2026 currentVersion,
", and upgrade is not supported");
// Reject files produced by a newer writer.
2028 if (version > currentVersion) {
2029 return reader.emitError(
"bytecode version ", version,
2030 " is newer than the current version ",
// NOTE(review): lazy loading is switched off here; the guarding condition is
// not visible in this excerpt -- presumably versions that predate
// lazy-loading support. Confirm against the full source.
2035 lazyLoading =
false;
/// Load the dialect this entry refers to and, when a version buffer was
/// recorded for it, read the dialect version through the dialect's bytecode
/// interface. Errors are reported through `reader`.
2043LogicalResult BytecodeDialect::load(
const DialectReader &reader,
// Unknown dialect: report it and point the user at the opt-in for
// unregistered dialects.
2049 return reader.emitError(
"dialect '")
2051 <<
"' is unknown. If this is intended, please call "
2052 "allowUnregisteredDialects() on the MLIRContext, or use "
2053 "-allow-unregistered-dialect with the MLIR tool used.";
2055 dialect = loadedDialect;
// The bytecode interface is optional; `interface` is null when the dialect
// does not implement it.
2060 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
// A version entry can only be decoded by a dialect that implements the
// bytecode interface.
2061 if (!versionBuffer.empty()) {
2063 return reader.emitError(
"dialect '")
2065 <<
"' does not implement the bytecode interface, "
2066 "but found a version entry";
// Decode the version from its own buffer using a reader scoped to it.
2067 EncodingReader encReader(versionBuffer, reader.getLoc());
2068 DialectReader versionReader = reader.withEncodingReader(encReader);
2069 loadedVersion = interface->readVersion(versionReader);
/// Parse the dialect section: the list of dialects referenced by the
/// bytecode (name plus optional version buffer), followed by the operation
/// names grouped per dialect. Populates `dialects`, `dialectsMap` and
/// `opNames`.
2077BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
2078 EncodingReader sectionReader(sectionData, fileLoc);
// Number of dialect entries, stored as a varint.
2081 uint64_t numDialects;
2082 if (
failed(sectionReader.parseVarInt(numDialects)))
2084 dialects.resize(numDialects);
// Alignment checker that reports failures at the section reader's location.
2086 const auto checkSectionAlignment = [&](
unsigned alignment) {
2087 return this->checkSectionAlignment(alignment, [&](
const auto &msg) {
2088 return sectionReader.emitError(msg);
// Read each dialect entry.
2093 for (uint64_t i = 0; i < numDialects; ++i) {
2094 dialects[i] = std::make_unique<BytecodeDialect>();
// NOTE(review): two encodings of the dialect name are visible -- a plain
// string reference, and a name index carrying a "version available" flag.
// The condition selecting between them is not visible here; presumably it
// depends on the bytecode version.
2098 if (
failed(stringReader.parseString(sectionReader, dialects[i]->name)))
2104 uint64_t dialectNameIdx;
2105 bool versionAvailable;
2106 if (
failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
2109 if (
failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
2110 dialects[i]->name)))
// When flagged, the dialect version follows as a nested section whose bytes
// are kept in `versionBuffer` for later decoding.
2112 if (versionAvailable) {
2114 if (
failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
2115 dialects[i]->versionBuffer)))
2118 emitError(fileLoc,
"expected dialect version section");
// Index the dialect by name for later lookup.
2122 dialectsMap[dialects[i]->name] = dialects[i].get();
// Callback that parses one operation name belonging to `dialect` and appends
// it to `opNames`. `wasRegistered` stays unset when the name is read without
// a flag.
2126 auto parseOpName = [&](BytecodeDialect *dialect) {
2128 std::optional<bool> wasRegistered;
2132 if (
failed(stringReader.parseString(sectionReader, opName)))
2135 bool wasRegisteredFlag;
2136 if (
failed(stringReader.parseStringWithFlag(sectionReader, opName,
2137 wasRegisteredFlag)))
2139 wasRegistered = wasRegisteredFlag;
2141 opNames.emplace_back(dialect, opName, wasRegistered);
// The total number of operation names is read so storage can be reserved
// up front.
2148 if (
failed(sectionReader.parseVarInt(numOps)))
2150 opNames.reserve(numOps);
// Consume operation-name groups until the section is exhausted.
2152 while (!sectionReader.empty())
/// Parse a reference to an operation name and return the resolved
/// `OperationName`. The `OperationName` is built on first use and cached in
/// the `BytecodeOperationName` entry. `wasRegistered` receives the
/// registration flag stored with the entry.
2158FailureOr<OperationName>
2159BytecodeReader::Impl::parseOpName(EncodingReader &reader,
2160 std::optional<bool> &wasRegistered) {
2161 BytecodeOperationName *opName =
nullptr;
2164 wasRegistered = opName->wasRegistered;
// Build the OperationName only if it has not been resolved yet.
2167 if (!opName->opName) {
// An empty op name means the operation is named after the dialect alone.
2172 if (opName->name.empty()) {
2173 opName->opName.emplace(opName->dialect->name,
getContext());
2176 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2177 dialectsMap, reader, version);
// Otherwise the full name is "<dialect>.<name>".
2180 opName->opName.emplace((opName->dialect->name +
"." + opName->name).str(),
2184 return *opName->opName;
/// Initialize the resource reader from the resource section and its offset
/// section. The two sections must be both present or both absent; a mismatch
/// is reported as an error at the file location.
2191LogicalResult BytecodeReader::Impl::parseResourceSection(
2192 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
2193 std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
// Exactly one of the two sections is present: report which one is missing.
2195 if (resourceData.has_value() != resourceOffsetData.has_value()) {
2196 if (resourceOffsetData)
2197 return emitError(fileLoc,
"unexpected resource offset section when "
2198 "resource section is not present");
2201 "expected resource offset section when resource section is present");
// Hand both sections to the resource reader.
2209 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2210 dialectsMap, reader, version);
2211 return resourceReader.initialize(fileLoc, config, dialects, stringReader,
2212 *resourceData, *resourceOffsetData,
2213 dialectReader, bufferOwnerRef);
/// Parse the use-list orders for a range of `numResults` values (op results
/// or block arguments) and return them keyed by the value's index within the
/// range.
2220FailureOr<BytecodeReader::Impl::UseListMapT>
2221BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
2222 uint64_t numResults) {
2223 BytecodeReader::Impl::UseListMapT map;
// With at most one value in the range the count is implicit (1); otherwise
// it is read from the stream.
2224 uint64_t numValuesToRead = 1;
2225 if (numResults > 1 &&
failed(reader.parseVarInt(numValuesToRead)))
2228 for (
size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
// Likewise, the index of the value is only encoded when the range holds more
// than one value; otherwise it is 0.
2229 uint64_t resultIdx = 0;
2230 if (numResults > 1 &&
failed(reader.parseVarInt(resultIdx)))
// Number of indices that follow, with a flag selecting the encoding.
2234 bool indexPairEncoding;
2235 if (
failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
// Read the indices themselves.
2238 SmallVector<unsigned, 4> useListOrders;
2239 for (
size_t idx = 0; idx < numValues; idx++) {
2241 if (
failed(reader.parseVarInt(index)))
2243 useListOrders.push_back(index);
// try_emplace keeps the first entry if the same index appears twice.
2247 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
2248 std::move(useListOrders)));
/// Reorder the use-list of `value`. The current uses are ranked by a use ID
/// derived from the owning operation's ID, and, when an order was parsed from
/// the bytecode for this value, that order is validated and applied on top.
2259LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
// Whether a custom order was parsed for this value.
2264 bool hasIncomingOrder =
// Tracks whether the existing use-list is already in descending use-ID
// order.
2269 bool alreadySorted =
true;
// Pairs of (position in the current use-list, use ID), seeded with the first
// use.
2273 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
2274 for (
auto item : llvm::drop_begin(llvm::enumerate(value.
getUses()))) {
2276 item.value(), operationIDs.at(item.value().getOwner()));
2277 alreadySorted &= prevID > currentID;
2278 currentOrder.push_back({item.index(), currentID});
// Nothing to do when the list is already ordered and no custom order exists.
2284 if (alreadySorted && !hasIncomingOrder)
// Order the uses by descending use ID.
2291 currentOrder.begin(), currentOrder.end(),
2292 [](
auto elem1,
auto elem2) { return elem1.second > elem2.second; });
// Without a custom order, the shuffle is just the sorted positions.
2294 if (!hasIncomingOrder) {
2298 SmallVector<unsigned> shuffle(llvm::make_first_range(currentOrder));
2304 UseListOrderStorage customOrder =
2306 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
// Index-pair encoding: expand (position, value) pairs into a full shuffle
// vector. Positions that are not mentioned keep the value written by
// std::iota.
2312 if (customOrder.isIndexPairEncoding) {
// Pairs require an even number of entries.
2314 if (shuffle.size() & 1)
2317 SmallVector<unsigned, 4> newShuffle(numUses);
2319 std::iota(newShuffle.begin(), newShuffle.end(), idx);
2320 for (idx = 0; idx < shuffle.size(); idx += 2)
2321 newShuffle[shuffle[idx]] = shuffle[idx + 1];
2323 shuffle = std::move(newShuffle);
// Validate that the shuffle is a permutation of [0, numUses): no duplicates,
// the expected length, and the expected sum numUses * (numUses - 1) / 2.
2330 uint64_t accumulator = 0;
2331 for (
const auto &elem : shuffle) {
2332 if (!set.insert(elem).second)
2334 accumulator += elem;
2336 if (numUses != shuffle.size() ||
2337 accumulator != (((numUses - 1) * numUses) >> 1))
// Compose the parsed order with the sorted order of the current uses.
2342 shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2343 currentOrder, [&](
auto item) {
return shuffle[item.first]; }));
/// Apply use-list ordering to every block argument and operation result
/// nested under `topLevelOp`. Returns failure if either walk was interrupted.
2348LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
// Number every operation in walk order; sortUseListOrder looks these IDs up
// in `operationIDs`.
2352 unsigned operationID = 0;
2354 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
// First walk: sort the use-lists of block arguments.
2356 auto blockWalk = topLevelOp->
walk([
this](
Block *block) {
2358 if (
failed(sortUseListOrder(arg)))
// Second walk: visits operations (body not fully visible here; presumably it
// sorts the use-lists of each op's results).
2363 auto resultWalk = topLevelOp->
walk([
this](Operation *op) {
2370 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
/// Parse the top-level IR section. Operations are first read into a
/// temporary module, then use-lists are processed, dialect version upgrades
/// are run, and finally the parsed operations are spliced into the
/// destination op list.
2378BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2380 EncodingReader reader(sectionData, fileLoc);
// Explicit stack of region states, processed by the loop below.
2383 std::vector<RegionReadState> regionStack;
// Parse into a temporary module; its body block acts as the top-level block
// and is treated as isolated from above.
2386 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2387 regionStack.emplace_back(*moduleOp, &reader,
true);
2388 regionStack.back().curBlocks.push_back(moduleOp->getBody());
2389 regionStack.back().curBlock = regionStack.back().curRegion->begin();
2390 if (
failed(parseBlockHeader(reader, regionStack.back())))
// Open the outermost value scope for the module body.
2392 valueScopes.emplace_back();
2393 valueScopes.back().push(regionStack.back());
// Keep parsing until every pending region state has been consumed.
2396 while (!regionStack.empty())
// Any placeholder left in `forwardRefOps` is an operand that was never
// defined.
2399 if (!forwardRefOps.empty()) {
2400 return reader.emitError(
2401 "not all forward unresolved forward operand references");
// Apply the use-list orders that were recorded while parsing.
2405 if (
failed(processUseLists(*moduleOp)))
2406 return reader.emitError(
2407 "parsed use-list orders were invalid and could not be applied");
// Give each dialect that recorded a version the chance to upgrade the IR.
2410 for (
const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2413 if (!byteCodeDialect->loadedVersion)
2415 if (byteCodeDialect->interface &&
2416 failed(byteCodeDialect->interface->upgradeFromVersion(
2417 *moduleOp, *byteCodeDialect->loadedVersion)))
// Move the parsed operations out of the temporary module into the
// destination list.
2426 auto &parsedOps = moduleOp->getBody()->getOperations();
2428 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
/// Parse the regions tracked by `readState`, reading the operations of each
/// block in turn. An operation that owns regions gets its own child state,
/// which is either recorded for lazy loading or pushed onto `regionStack`.
/// When all regions are done, the state is popped from `regionStack`.
2433BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
2434 RegionReadState &readState) {
// Alignment checker that reports failures at the file location.
2435 const auto checkSectionAlignment = [&](
unsigned alignment) {
2436 return this->checkSectionAlignment(
2437 alignment, [&](
const auto &msg) {
return emitError(fileLoc, msg); });
// Resume from the region the state currently points at.
2443 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
2449 if (
failed(parseRegion(readState)))
2453 if (readState.curRegion->empty())
2458 EncodingReader &reader = *readState.reader;
// Read the operations remaining in the current block.
2460 while (readState.numOpsRemaining--) {
2463 bool isIsolatedFromAbove =
false;
2464 FailureOr<Operation *> op =
2465 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
// Operations with regions need a child state of their own.
2473 if ((*op)->getNumRegions()) {
2474 RegionReadState childState(*op, &reader, isIsolatedFromAbove);
// NOTE(review): the regions may be stored in a nested IR section, in which
// case the child state owns a reader over that section. The condition
// guarding this path is not visible in this excerpt.
2479 ArrayRef<uint8_t> sectionData;
2480 if (
failed(reader.parseSection(sectionID, checkSectionAlignment,
2484 return emitError(fileLoc,
"expected IR section for region");
2485 childState.owningReader =
2486 std::make_unique<EncodingReader>(sectionData, fileLoc);
2487 childState.reader = childState.owningReader.get();
// Under lazy loading, defer the op unless the callback asks for it to be
// materialized now.
2491 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
2492 lazyLoadableOps.emplace_back(*op, std::move(childState));
2493 lazyLoadableOpsMap.try_emplace(*op,
2494 std::prev(lazyLoadableOps.end()));
2498 regionStack.push_back(std::move(childState));
// An isolated-from-above op gets a fresh value scope.
2501 if (isIsolatedFromAbove)
2502 valueScopes.emplace_back();
// Advance to the next block of the region, if there is one.
2508 if (++readState.curBlock == readState.curRegion->end())
2510 if (
failed(parseBlockHeader(reader, readState)))
// The region is finished: reset the block cursor and release its values.
2515 readState.curBlock = {};
2516 valueScopes.back().pop(readState);
// All regions are done: drop the value scope opened for an isolated op, then
// this state.
2521 if (readState.isIsolatedFromAbove) {
2522 assert(!valueScopes.empty() &&
"Expect a valueScope after reading region");
2523 valueScopes.pop_back();
2525 assert(!regionStack.empty() &&
"Expect a regionStack after reading region");
2526 regionStack.pop_back();
/// Parse a single operation -- name, properties, results, operands,
/// successors, use-list orders and region count -- and append it to the
/// current block of `readState`. The contents of the regions are not parsed
/// here. `isIsolatedFromAbove` is set from the flag stored alongside the
/// region count.
2530FailureOr<Operation *>
2531BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2532 RegionReadState &readState,
2533 bool &isIsolatedFromAbove) {
// Resolve the operation name.
2535 std::optional<bool> wasRegistered;
2536 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
// A one-byte mask is read next; presumably it selects which optional
// components follow (the uses of `opMask` are not visible here).
2543 if (
failed(reader.parseByte(opMask)))
2553 OperationState opState(opLoc, *opName);
2557 DictionaryAttr dictAttr;
2568 "Unexpected missing `wasRegistered` opname flag at "
2569 "bytecode version ")
2570 << version <<
" with properties.";
// `wasRegistered` is an optional: this checks that the flag is present, not
// its value.
2574 if (wasRegistered) {
2575 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2576 dialectsMap, reader, version);
2578 propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
// Result types.
2590 uint64_t numResults;
2591 if (
failed(reader.parseVarInt(numResults)))
2593 opState.
types.resize(numResults);
2594 for (
int i = 0, e = numResults; i < e; ++i)
// Operands; a null value from parseOperand signals failure.
2601 uint64_t numOperands;
2602 if (
failed(reader.parseVarInt(numOperands)))
2604 opState.
operands.resize(numOperands);
2605 for (
int i = 0, e = numOperands; i < e; ++i)
2606 if (!(opState.
operands[i] = parseOperand(reader)))
// Successors.
2613 if (
failed(reader.parseVarInt(numSuccs)))
2616 for (
int i = 0, e = numSuccs; i < e; ++i) {
// Optional use-list orders for the results, keyed by result index.
2625 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2628 size_t numResults = opState.
types.size();
2629 auto parseResult = parseUseListOrderForRange(reader, numResults);
2632 resultIdxToUseListMap = std::move(*parseResult);
// Region count; the accompanying flag marks the op as isolated from above.
2637 uint64_t numRegions;
2638 if (
failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
// Create empty regions; their contents are parsed by the caller.
2641 opState.
regions.reserve(numRegions);
2642 for (
int i = 0, e = numRegions; i < e; ++i)
2643 opState.
regions.push_back(std::make_unique<Region>());
// Insert the new operation at the end of the current block.
2648 readState.curBlock->push_back(op);
// Look up the parsed use-list order for each result index that has one.
2659 if (resultIdxToUseListMap.has_value()) {
2661 if (resultIdxToUseListMap->contains(idx)) {
2663 resultIdxToUseListMap->at(idx));
/// Parse the header of the region `readState.curRegion` points at: read the
/// block and value counts, create the blocks, open value storage for the
/// region, and parse the header of the first block.
2670LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2671 EncodingReader &reader = *readState.reader;
// Number of blocks in the region.
2675 if (
failed(reader.parseVarInt(numBlocks)))
// Number of values defined in the region.
2684 if (
failed(reader.parseVarInt(numValues)))
2686 readState.numValues = numValues;
// Create all blocks up front and add them to the region.
2690 readState.curBlocks.clear();
2691 readState.curBlocks.reserve(numBlocks);
2692 for (uint64_t i = 0; i < numBlocks; ++i) {
2693 readState.curBlocks.push_back(
new Block());
2694 readState.curRegion->push_back(readState.curBlocks.back());
// Reserve value slots for this region in the innermost value scope.
2698 valueScopes.back().push(readState);
// Start reading at the first block.
2701 readState.curBlock = readState.curRegion->begin();
2702 return parseBlockHeader(reader, readState);
/// Parse the header of the current block: the number of operations it holds
/// (stored into `readState.numOpsRemaining`), its arguments when present,
/// and any use-list orders recorded for those arguments.
2706BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2707 RegionReadState &readState) {
// The op count carries a flag telling whether block arguments follow.
2709 if (
failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2713 if (hasArgs &&
failed(parseBlockArguments(reader, &*readState.curBlock)))
// A byte indicating whether use-list orders follow is only read when the
// block has arguments. NOTE(review): any additional guard on this path is
// not visible in this excerpt.
2720 uint8_t hasUseListOrders = 0;
2721 if (hasArgs &&
failed(reader.parseByte(hasUseListOrders)))
2724 if (!hasUseListOrders)
2727 Block &blk = *readState.curBlock;
2728 auto argIdxToUseListMap =
2730 if (
failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
// Look up the parsed use-list order for each argument index that has one.
2734 if (argIdxToUseListMap->contains(idx))
2736 argIdxToUseListMap->at(idx));
/// Parse the argument list of a block: the argument count followed by a type
/// (and optionally a location) for each argument. Types and locations are
/// collected into `argTypes` / `argLocs`.
2742LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
// Number of arguments.
2746 if (
failed(reader.parseVarInt(numArgs)))
2749 SmallVector<Type> argTypes;
2750 SmallVector<Location> argLocs;
2751 argTypes.reserve(numArgs);
2752 argLocs.reserve(numArgs);
// Default location for arguments that do not carry an explicit one.
2754 Location unknownLoc = UnknownLoc::get(config.
getContext());
2757 LocationAttr argLoc = unknownLoc;
// The type index is stored with a flag telling whether a location follows;
// the index is resolved through the attribute/type reader.
2762 if (
failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2763 !(argType = attrTypeReader.resolveType(typeIdx)))
2773 argTypes.push_back(argType);
2774 argLocs.push_back(argLoc);
/// Parse a reference to a value in the innermost value scope and return it.
/// A forward-reference placeholder is stored in the slot via
/// createForwardRef(); NOTE(review): the guard is not visible here --
/// presumably this happens only when the slot has not been defined yet.
2784Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2785 std::vector<Value> &values = valueScopes.back().values;
2786 Value *value =
nullptr;
2792 *value = createForwardRef();
/// Register newly defined values in the innermost value scope, assigning
/// them consecutive value IDs. If a slot already holds a forward-reference
/// placeholder, all uses of the placeholder are redirected to the new value
/// and the placeholder op is moved to `openForwardRefOps`.
2796LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2798 ValueScope &valueScope = valueScopes.back();
2799 std::vector<Value> &values = valueScope.values;
// `valueID` is a reference, so advancing it in the loop below updates the
// scope's next free ID.
2801 unsigned &valueID = valueScope.nextValueIDs.back();
2802 unsigned valueIDEnd = valueID + newValues.size();
// The new IDs must fit in the slots reserved for the region.
2803 if (valueIDEnd > values.size()) {
2804 return reader.emitError(
2805 "value index range was outside of the expected range for "
2806 "the parent region, got [",
2807 valueID,
", ", valueIDEnd,
"), but the maximum index was ",
2812 for (
unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2813 Value newValue = newValues[i];
// Store the new value; a non-null previous value is a forward reference that
// must now be resolved.
2816 if (Value oldValue = std::exchange(values[valueID], newValue)) {
2817 Operation *forwardRefOp = oldValue.getDefiningOp();
// The previous value must come from a placeholder op in `forwardRefOps`;
// anything else means the ID was defined twice.
2822 assert(forwardRefOp && forwardRefOp->
getBlock() == &forwardRefOps &&
2823 "value index was already defined?");
2825 oldValue.replaceAllUsesWith(newValue);
// Move the placeholder to the open list, from which createForwardRef takes
// ops when that list is non-empty.
2826 forwardRefOp->
moveBefore(&openForwardRefOps, openForwardRefOps.end());
/// Return a placeholder value standing in for an operand that has not been
/// defined yet. The placeholder is result 0 of the last op in
/// `forwardRefOps`.
2832Value BytecodeReader::Impl::createForwardRef() {
// Reuse a previously resolved placeholder op when one is available.
2835 if (!openForwardRefOps.empty()) {
2836 Operation *op = &openForwardRefOps.back();
2837 op->
moveBefore(&forwardRefOps, forwardRefOps.end());
2841 return forwardRefOps.back().getResult(0);
2851 llvm::MemoryBufferRef buffer,
const ParserConfig &config,
bool lazyLoading,
2852 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2856 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
2862 return impl->read(block, lazyOpsCallback);
2866 return impl->getNumOpsToMaterialize();
2870 return impl->isMaterializable(op);
2875 return impl->materialize(op, lazyOpsCallback);
2880 return impl->finalize(shouldMaterialize);
2884 return buffer.getBuffer().starts_with(
"ML\xefR");
2893 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2899 "input buffer is not an MLIR bytecode file");
2903 buffer, bufferOwnerRef);
2904 return reader.
read(block,
nullptr);
2915 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
static LogicalResult parseDialectGrouping(EncodingReader &reader, MutableArrayRef< std::unique_ptr< BytecodeDialect > > dialects, function_ref< LogicalResult(BytecodeDialect *)> entryCallback)
Parse a single dialect group encoded in the byte stream.
static LogicalResult readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
Read the bytecode from the provided memory buffer reference.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version)
Returns true if the given top-level section ID is optional.
static LogicalResult parseResourceGroup(Location fileLoc, bool allowEmpty, EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef, function_ref< StringRef(StringRef)> remapKey={}, function_ref< LogicalResult(StringRef)> processKeyFn={})
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, uint64_t index, T &entry, StringRef entryStr)
Resolve an index into the given entry list.
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, T &entry, StringRef entryStr)
Parse and resolve an index into the given entry list.
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::string diag(const llvm::Value &value)
MutableArrayRef< char > getMutableData()
Return a mutable reference to the raw underlying data of this blob.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
bool isMutable() const
Return if the data of this blob is mutable.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
OpListType & getOperations()
BlockArgListType getArguments()
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Type > > > getTypeCallbacks() const
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Attribute > > > getAttributeCallbacks() const
Returns the callbacks available to the parser.
This class is used to read a bytecode buffer and translate it into MLIR.
LogicalResult materializeAll()
Materialize all operations.
LogicalResult read(Block *block, llvm::function_ref< bool(Operation *)> lazyOps)
Read the bytecode defined within buffer into the given block.
bool isMaterializable(Operation *op)
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, llvm::MemoryBufferRef buffer, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize)
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the client decide if the op should be deleted or materialized.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback)
Materialize the provided operation, invoke the lazyOpsCallback on every newly found lazy operation.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback=[](Operation *) { return false;})
Materialize the provided operation.
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize=[](Operation *) { return true;})
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the client decide if the op should be deleted or materialized.
BytecodeReader(llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef={})
Create a bytecode reader for the given buffer.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
bool isMaterializable(Operation *op)
Return true if the provided op is materializable.
LogicalResult readTopLevel(Block *block, llvm::function_ref< bool(Operation *)> lazyOps=[](Operation *) { return false;})
Read the operations defined within the given memory buffer, containing MLIR bytecode, into the provided block.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
bool isRegistered() const
Return if this operation is registered.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, PropertyRef properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a configuration for the MLIR assembly parser.
MLIRContext * getContext() const
Return the MLIRContext to be used when parsing.
bool shouldVerifyAfterParse() const
Returns if the parser should verify the IR after parsing.
BytecodeReaderConfig & getBytecodeReaderConfig() const
Returns the parsing configurations associated to the bytecode read.
AsmResourceParser * getResourceParser(StringRef name) const
Return the resource parser registered to the given name, or nullptr if no parser with name is registered.
BlockListType::iterator iterator
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a given context.
static AsmResourceBlob allocateWithAlign(ArrayRef< char > data, size_t align, AsmResourceBlob::DeleterFn deleter={}, bool dataIsMutable=false)
Create a new unmanaged resource directly referencing the provided data.
This class provides an abstraction over the different types of ranges over Values.
bool use_empty() const
Returns true if this value has no uses.
void shuffleUseList(ArrayRef< unsigned > indices)
Shuffle the use list order according to the provided indices.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
unsigned getNumUses() const
This method computes the number of uses of this Value.
bool hasOneUse() const
Returns true if this value has exactly one use.
use_iterator use_begin() const
static WalkResult advance()
static WalkResult interrupt()
@ kAttrType
This section contains the attributes and types referenced within an IR module.
@ kAttrTypeOffset
This section contains the offsets for the attribute and types within the AttrType section.
@ kIR
This section contains the list of operations serialized into the bytecode, and their nested regions/operations.
@ kResource
This section contains the resources of the bytecode.
@ kResourceOffset
This section contains the offsets of resources within the Resource section.
@ kDialect
This section contains the dialects referenced within an IR module.
@ kString
This section contains strings referenced within the bytecode.
@ kDialectVersions
This section contains the versions of each dialect.
@ kProperties
This section contains the properties for the operations.
@ kNumSections
The total number of section types.
static uint64_t getUseID(OperandT &val, unsigned ownerID)
Get the unique ID of a value use.
@ kUseListOrdering
Use-list ordering started to be encoded in version 3.
@ kAlignmentByte
An arbitrary value used to fill alignment padding.
@ kVersion
The current bytecode version.
@ kLazyLoading
Support for lazy-loading of isolated region was added in version 2.
@ kDialectVersioning
Dialects versioning was added in version 1.
@ kElideUnknownBlockArgLocation
Avoid recording unknown locations on block arguments (compression) started in version 4.
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attributes.
@ kMinSupportedVersion
The minimum supported version of the bytecode.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
StringRef toString(AsmResourceEntryKind kind)
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
bool isBytecode(llvm::MemoryBufferRef buffer)
Returns true if the given buffer starts with the magic bytes that signal MLIR bytecode.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AsmResourceEntryKind
This enum represents the different kinds of resource values.
LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config)
Read the operations defined within the given memory buffer, containing MLIR bytecode, into the provided block.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.