20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/Endian.h"
25#include "llvm/Support/MemoryBufferRef.h"
26#include "llvm/Support/SourceMgr.h"
36#define DEBUG_TYPE "mlir-bytecode-reader"
48 return "AttrType (2)";
50 return "AttrTypeOffset (3)";
54 return "Resource (5)";
56 return "ResourceOffset (6)";
58 return "DialectVersions (7)";
60 return "Properties (8)";
62 return (
"Unknown (" + Twine(
static_cast<unsigned>(sectionID)) +
")").str();
82 llvm_unreachable(
"unknown section ID");
/// Builds a reader over `contents`. The cursor (`dataIt`) starts at the first
/// byte of the buffer, and `fileLoc` is the location attached to every
/// diagnostic this reader emits.
93 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
94 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
95 explicit EncodingReader(StringRef contents, Location fileLoc)
96 : EncodingReader({
reinterpret_cast<const uint8_t *
>(contents.data()),
/// Returns true when the cursor has reached the end of the buffer, i.e. there
/// are no unread bytes left.
101 bool empty()
const {
return dataIt == buffer.end(); }
/// Returns the number of unread bytes, measured from the cursor to the end of
/// the buffer.
104 size_t size()
const {
return buffer.end() - dataIt; }
/// Advances the cursor until its address is a multiple of `alignment`,
/// consuming padding bytes on the way. Emits an error when `alignment` is not
/// a power of two, when a padding byte cannot be read, and (per the message
/// below) when a padding byte is not the expected 0xCB marker.
107 LogicalResult alignTo(
unsigned alignment) {
// The mask test used below is only meaningful for power-of-two alignments.
108 if (!llvm::isPowerOf2_32(alignment))
109 return emitError(
"expected alignment to be a power-of-two");
// A pointer is unaligned when any of its low `log2(alignment)` address bits
// is set.
111 auto isUnaligned = [&](
const uint8_t *ptr) {
112 return ((uintptr_t)ptr & (alignment - 1)) != 0;
// Consume one padding byte per iteration until the cursor is aligned.
119 while (isUnaligned(dataIt)) {
121 if (
failed(parseByte(padding)))
124 return emitError(
"expected alignment byte (0xCB), but got: '0x" +
125 llvm::utohexstr(padding) +
"'");
// Defensive re-check after the padding loop; not expected to trigger.
131 if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
132 return emitError(
"expected data iterator aligned to ", alignment,
133 ", but got pointer: '0x" +
134 llvm::utohexstr((uintptr_t)dataIt) +
"'");
/// Emits an error diagnostic at the reader's file location, appending every
/// argument in `args` to the diagnostic message in order.
141 template <
typename... Args>
142 InFlightDiagnostic
emitError(Args &&...args)
const {
143 return ::emitError(fileLoc).
append(std::forward<Args>(args)...);
/// Emits an error diagnostic at the reader's file location with no message
/// attached yet; the caller can stream text into the returned diagnostic.
145 InFlightDiagnostic
emitError()
const { return ::emitError(fileLoc); }
148 template <
typename T>
149 LogicalResult parseByte(T &value) {
151 return emitError(
"attempting to parse a byte at the end of the bytecode");
152 value =
static_cast<T
>(*dataIt++);
156 LogicalResult parseBytes(
size_t length, ArrayRef<uint8_t> &
result) {
157 if (length > size()) {
158 return emitError(
"attempting to parse ", length,
" bytes when only ",
161 result = {dataIt, length};
167 LogicalResult parseBytes(
size_t length, uint8_t *
result) {
168 if (length > size()) {
169 return emitError(
"attempting to parse ", length,
" bytes when only ",
172 memcpy(
result, dataIt, length);
179 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
180 uint64_t &alignment) {
182 if (
failed(parseVarInt(alignment)) ||
failed(parseVarInt(dataSize)) ||
183 failed(alignTo(alignment)))
185 return parseBytes(dataSize, data);
195 LogicalResult parseVarInt(uint64_t &
result) {
202 if (LLVM_LIKELY(
result & 1)) {
210 if (LLVM_UNLIKELY(
result == 0)) {
211 llvm::support::ulittle64_t resultLE;
212 if (
failed(parseBytes(
sizeof(resultLE),
213 reinterpret_cast<uint8_t *
>(&resultLE))))
218 return parseMultiByteVarInt(
result);
224 LogicalResult parseSignedVarInt(uint64_t &
result) {
234 LogicalResult parseVarIntWithFlag(uint64_t &
result,
bool &flag) {
243 LogicalResult skipBytes(
size_t length) {
244 if (length > size()) {
245 return emitError(
"attempting to skip ", length,
" bytes when only ",
254 LogicalResult parseNullTerminatedString(StringRef &
result) {
255 const char *startIt = (
const char *)dataIt;
256 const char *nulIt = (
const char *)memchr(startIt, 0, size());
259 "malformed null-terminated string, no null character found");
261 result = StringRef(startIt, nulIt - startIt);
262 dataIt = (
const uint8_t *)nulIt + 1;
267 using ValidateAlignmentFn =
function_ref<LogicalResult(
unsigned alignment)>;
// `alignmentValidator` is invoked with the alignment read from the section
// header and may reject it.
272 ValidateAlignmentFn alignmentValidator,
// Out-parameter: receives the payload bytes of the section (filled in by the
// final parseBytes call below).
273 ArrayRef<uint8_t> &sectionData) {
274 uint8_t sectionIDAndHasAlignment;
// The header starts with one byte combining the section ID with an
// "has alignment" flag, followed by the varint payload length.
276 if (
failed(parseByte(sectionIDAndHasAlignment)) ||
277 failed(parseVarInt(length)))
// The top bit of the first byte is the alignment flag.
284 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
289 return emitError(
"invalid section ID: ",
unsigned(sectionID));
// Read the alignment varint for this section.
295 if (
failed(parseVarInt(alignment)))
// Give the caller a chance to veto the requested alignment before the cursor
// is moved.
330 if (
failed(alignmentValidator(alignment)))
331 return emitError(
"failed to align section ID: ",
unsigned(sectionID));
// Skip the padding bytes that precede the aligned payload.
334 if (
failed(alignTo(alignment)))
// Hand back a view of the `length` payload bytes and advance past them.
339 return parseBytes(
static_cast<size_t>(length), sectionData);
/// Returns the file location this reader attaches to its diagnostics.
342 Location getLoc()
const {
return fileLoc; }
/// Out-of-line path for varints that span more than one byte. On entry
/// `result` holds the first byte already read by the caller; on success it is
/// overwritten with the decoded value.
351 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &
result) {
// The number of trailing zero bits in the first byte gives the number of
// additional bytes that follow it.
357 uint32_t numBytes = llvm::countr_zero<uint32_t>(
result);
358 assert(numBytes > 0 && numBytes <= 7 &&
359 "unexpected number of trailing zeros in varint encoding");
// Keep the first byte in byte 0 of a little-endian 64-bit value and read the
// remaining bytes directly into bytes 1..numBytes.
362 llvm::support::ulittle64_t resultLE(
result);
364 parseBytes(numBytes,
reinterpret_cast<uint8_t *
>(&resultLE) + 1)))
// Shift out the low `numBytes + 1` prefix bits (the zero run plus one more
// bit — presumably the terminating marker bit) to leave only the payload.
369 result = resultLE >> (numBytes + 1);
374 ArrayRef<uint8_t> buffer;
377 const uint8_t *dataIt;
/// Looks up element `index` of `entries` and stores it in `entry`.
///
/// Emits "invalid <entryStr> index: <index>" through `reader` when `index` is
/// out of range. `entryStr` is only used to name the kind of entry in that
/// message.
388template <
typename RangeT,
typename T>
389static LogicalResult
resolveEntry(EncodingReader &reader, RangeT &entries,
390 uint64_t
index, T &entry,
391 StringRef entryStr) {
392 if (
index >= entries.size())
393 return reader.emitError(
"invalid ", entryStr,
" index: ",
index);
// When the range's element type converts to T the element itself is
// assigned; the second assignment below (the alternative branch) stores the
// element's address instead, for callers that want a pointer into `entries`.
396 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
397 entry = entries[
index];
399 entry = &entries[
index];
/// Reads a varint index from `reader` and resolves it against `entries` via
/// `resolveEntry`, storing the result in `entry`. `entryStr` names the entry
/// kind in the error message for an out-of-range index.
404template <
typename RangeT,
typename T>
405static LogicalResult
parseEntry(EncodingReader &reader, RangeT &entries,
406 T &entry, StringRef entryStr) {
408 if (failed(reader.parseVarInt(entryIdx)))
410 return resolveEntry(reader, entries, entryIdx, entry, entryStr);
420class StringSectionReader {
423 LogicalResult
initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
427 LogicalResult parseString(EncodingReader &reader, StringRef &
result)
const {
434 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &
result,
437 if (
failed(reader.parseVarIntWithFlag(entryIdx, flag)))
439 return parseStringAtIndex(reader, entryIdx,
result);
444 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
445 StringRef &
result)
const {
451 SmallVector<StringRef> strings;
455LogicalResult StringSectionReader::initialize(
Location fileLoc,
457 EncodingReader stringReader(sectionData, fileLoc);
461 if (
failed(stringReader.parseVarInt(numStrings)))
463 strings.resize(numStrings);
467 size_t stringDataEndOffset = sectionData.size();
468 for (StringRef &
string : llvm::reverse(strings)) {
470 if (
failed(stringReader.parseVarInt(stringSize)))
472 if (stringDataEndOffset < stringSize) {
473 return stringReader.emitError(
474 "string size exceeds the available data size");
478 size_t stringOffset = stringDataEndOffset - stringSize;
480 reinterpret_cast<const char *
>(sectionData.data() + stringOffset),
482 stringDataEndOffset = stringOffset;
487 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
488 return stringReader.emitError(
"unexpected trailing data between the "
489 "offsets for strings and their data");
502struct BytecodeDialect {
507 LogicalResult
load(
const DialectReader &reader, MLIRContext *ctx);
511 Dialect *getLoadedDialect()
const {
513 "expected `load` to be invoked before `getLoadedDialect`");
520 std::optional<Dialect *> dialect;
525 const BytecodeDialectInterface *
interface =
nullptr;
531 ArrayRef<uint8_t> versionBuffer;
534 std::unique_ptr<DialectVersion> loadedVersion;
538struct BytecodeOperationName {
/// Records the owning dialect, the operation name string and the optional
/// `wasRegistered` flag exactly as given. `opName` is not initialized here,
/// so it starts out as an empty optional.
539 BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
540 std::optional<bool> wasRegistered)
541 : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
545 std::optional<OperationName> opName;
548 BytecodeDialect *dialect;
555 std::optional<bool> wasRegistered;
561 EncodingReader &reader,
563 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
565 std::unique_ptr<BytecodeDialect> *dialect;
566 if (failed(
parseEntry(reader, dialects, dialect,
"dialect")))
569 if (failed(reader.parseVarInt(numEntries)))
572 for (uint64_t i = 0; i < numEntries; ++i)
573 if (failed(entryCallback(dialect->get())))
584class ResourceSectionReader {
589 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
590 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
591 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
592 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
595 LogicalResult parseResourceHandle(EncodingReader &reader,
596 AsmDialectResourceHandle &
result)
const {
602 SmallVector<AsmDialectResourceHandle> dialectResources;
603 llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
606class ParsedResourceEntry :
public AsmParsedResourceEntry {
609 EncodingReader &reader, StringSectionReader &stringReader,
610 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
611 : key(key), kind(kind), reader(reader), stringReader(stringReader),
612 bufferOwnerRef(bufferOwnerRef) {}
613 ~ParsedResourceEntry()
override =
default;
615 StringRef getKey() const final {
return key; }
617 InFlightDiagnostic
emitError() const final {
return reader.emitError(); }
621 FailureOr<bool> parseAsBool() const final {
622 if (kind != AsmResourceEntryKind::Bool)
623 return emitError() <<
"expected a bool resource entry, but found a "
624 <<
toString(kind) <<
" entry instead";
627 if (
failed(reader.parseByte(value)))
631 FailureOr<std::string> parseAsString() const final {
632 if (kind != AsmResourceEntryKind::String)
633 return emitError() <<
"expected a string resource entry, but found a "
634 <<
toString(kind) <<
" entry instead";
637 if (
failed(stringReader.parseString(reader,
string)))
642 FailureOr<AsmResourceBlob>
643 parseAsBlob(BlobAllocatorFn allocator)
const final {
644 if (kind != AsmResourceEntryKind::Blob)
645 return emitError() <<
"expected a blob resource entry, but found a "
646 <<
toString(kind) <<
" entry instead";
648 ArrayRef<uint8_t> data;
650 if (
failed(reader.parseBlobAndAlignment(data, alignment)))
655 if (bufferOwnerRef) {
656 ArrayRef<char> charData(
reinterpret_cast<const char *
>(data.data()),
664 [bufferOwnerRef = bufferOwnerRef](
void *,
size_t,
size_t) {});
669 AsmResourceBlob blob = allocator(data.size(), alignment);
670 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.
getData().data()) &&
672 "blob allocator did not return a properly aligned address");
680 EncodingReader &reader;
681 StringSectionReader &stringReader;
682 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
689 EncodingReader &offsetReader, EncodingReader &resourceReader,
690 StringSectionReader &stringReader, T *handler,
691 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
693 function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
694 uint64_t numResources;
695 if (
failed(offsetReader.parseVarInt(numResources)))
698 for (uint64_t i = 0; i < numResources; ++i) {
701 uint64_t resourceOffset;
702 ArrayRef<uint8_t> data;
703 if (
failed(stringReader.parseString(offsetReader, key)) ||
704 failed(offsetReader.parseVarInt(resourceOffset)) ||
705 failed(offsetReader.parseByte(kind)) ||
706 failed(resourceReader.parseBytes(resourceOffset, data)))
710 if ((processKeyFn &&
failed(processKeyFn(key))))
715 if (allowEmpty && data.empty())
723 EncodingReader entryReader(data, fileLoc);
725 ParsedResourceEntry entry(key, kind, entryReader, stringReader,
727 if (
failed(handler->parseResource(entry)))
729 if (!entryReader.empty()) {
730 return entryReader.emitError(
731 "unexpected trailing bytes in resource entry '", key,
"'");
737LogicalResult ResourceSectionReader::initialize(
738 Location fileLoc,
const ParserConfig &
config,
739 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
740 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
741 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
742 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
743 EncodingReader resourceReader(sectionData, fileLoc);
744 EncodingReader offsetReader(offsetSectionData, fileLoc);
747 uint64_t numExternalResourceGroups;
748 if (
failed(offsetReader.parseVarInt(numExternalResourceGroups)))
753 auto parseGroup = [&](
auto *handler,
bool allowEmpty =
false,
755 auto resolveKey = [&](StringRef key) -> StringRef {
756 auto it = dialectResourceHandleRenamingMap.find(key);
757 if (it == dialectResourceHandleRenamingMap.end())
763 stringReader, handler, bufferOwnerRef, resolveKey,
768 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
770 if (
failed(stringReader.parseString(offsetReader, key)))
775 AsmResourceParser *handler =
config.getResourceParser(key);
777 emitWarning(fileLoc) <<
"ignoring unknown external resources for '" << key
781 if (
failed(parseGroup(handler)))
787 while (!offsetReader.empty()) {
788 std::unique_ptr<BytecodeDialect> *dialect;
790 failed((*dialect)->load(dialectReader, ctx)))
792 Dialect *loadedDialect = (*dialect)->getLoadedDialect();
793 if (!loadedDialect) {
794 return resourceReader.emitError()
795 <<
"dialect '" << (*dialect)->name <<
"' is unknown";
797 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
799 return resourceReader.emitError()
800 <<
"unexpected resources for dialect '" << (*dialect)->name <<
"'";
804 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
805 FailureOr<AsmDialectResourceHandle> handle =
806 handler->declareResource(key);
808 return resourceReader.emitError()
809 <<
"unknown 'resource' key '" << key <<
"' for dialect '"
810 << (*dialect)->name <<
"'";
812 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
813 dialectResources.push_back(*handle);
819 if (
failed(parseGroup(handler,
true, processResourceKeyFn)))
851class AttrTypeReader {
853 template <
typename T>
858 BytecodeDialect *dialect =
nullptr;
861 bool hasCustomEncoding =
false;
863 ArrayRef<uint8_t> data;
865 using AttrEntry = Entry<Attribute>;
866 using TypeEntry = Entry<Type>;
/// Stores references to the shared string/resource readers, the dialect map,
/// the bytecode version and the parser configuration, plus the file location
/// used for diagnostics. Nothing is parsed here; `bytecodeVersion` is held by
/// reference, so the reader observes the caller's current value.
869 AttrTypeReader(
const StringSectionReader &stringReader,
870 const ResourceSectionReader &resourceReader,
871 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
872 uint64_t &bytecodeVersion, Location fileLoc,
873 const ParserConfig &
config)
874 : stringReader(stringReader), resourceReader(resourceReader),
875 dialectsMap(dialectsMap), fileLoc(fileLoc),
876 bytecodeVersion(bytecodeVersion), parserConfig(
config) {}
880 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
881 ArrayRef<uint8_t> sectionData,
882 ArrayRef<uint8_t> offsetSectionData);
884 LogicalResult readAttribute(uint64_t index, Attribute &
result,
885 uint64_t depth = 0) {
886 return readEntry(attributes, index,
result,
"attribute", depth);
889 LogicalResult readType(uint64_t index, Type &
result, uint64_t depth = 0) {
890 return readEntry(types, index,
result,
"type", depth);
895 Attribute resolveAttribute(
size_t index, uint64_t depth = 0) {
896 return resolveEntry(attributes, index,
"Attribute", depth);
898 Type resolveType(
size_t index, uint64_t depth = 0) {
902 Attribute getAttributeOrSentinel(
size_t index) {
903 if (index >= attributes.size())
905 return attributes[index].entry;
907 Type getTypeOrSentinel(
size_t index) {
908 if (index >= types.size())
910 return types[index].entry;
916 if (
failed(reader.parseVarInt(attrIdx)))
918 result = resolveAttribute(attrIdx);
921 LogicalResult parseOptionalAttribute(EncodingReader &reader,
925 if (
failed(reader.parseVarIntWithFlag(attrIdx, flag)))
929 result = resolveAttribute(attrIdx);
935 if (
failed(reader.parseVarInt(typeIdx)))
937 result = resolveType(typeIdx);
941 template <
typename T>
943 Attribute baseResult;
946 if ((
result = dyn_cast<T>(baseResult)))
948 return reader.emitError(
"expected attribute of type: ",
949 llvm::getTypeName<T>(),
", but got: ", baseResult);
953 enum class EntryKind { Attribute, Type };
956 void addDeferredParsing(uint64_t index, EntryKind kind) {
957 deferredWorklist.emplace_back(index, kind);
961 bool isResolving()
const {
return resolving; }
965 template <
typename T>
966 T
resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
967 StringRef entryType, uint64_t depth = 0);
971 template <
typename T>
972 LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
973 T &
result, StringRef entryType, uint64_t depth);
977 template <
typename T>
978 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
979 StringRef entryType, uint64_t index,
984 template <
typename T>
985 LogicalResult parseAsmEntry(T &
result, EncodingReader &reader,
986 StringRef entryType);
990 const StringSectionReader &stringReader;
994 const ResourceSectionReader &resourceReader;
998 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1001 SmallVector<AttrEntry> attributes;
1002 SmallVector<TypeEntry> types;
1008 uint64_t &bytecodeVersion;
1011 const ParserConfig &parserConfig;
1017 std::vector<std::pair<uint64_t, EntryKind>> deferredWorklist;
1020 bool resolving =
false;
1023class DialectReader :
public DialectBytecodeReader {
1025 DialectReader(AttrTypeReader &attrTypeReader,
1026 const StringSectionReader &stringReader,
1027 const ResourceSectionReader &resourceReader,
1028 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
1029 EncodingReader &reader, uint64_t &bytecodeVersion,
1031 : attrTypeReader(attrTypeReader), stringReader(stringReader),
1032 resourceReader(resourceReader), dialectsMap(dialectsMap),
1033 reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
1035 InFlightDiagnostic
emitError(
const Twine &msg)
const override {
1036 return reader.emitError(msg);
1039 FailureOr<const DialectVersion *>
1040 getDialectVersion(StringRef dialectName)
const override {
1042 auto dialectEntry = dialectsMap.find(dialectName);
1043 if (dialectEntry == dialectsMap.end())
1048 if (
failed(dialectEntry->getValue()->load(*
this, getLoc().
getContext())) ||
1049 dialectEntry->getValue()->loadedVersion ==
nullptr)
1051 return dialectEntry->getValue()->loadedVersion.get();
1054 MLIRContext *
getContext()
const override {
return getLoc().getContext(); }
1056 uint64_t getBytecodeVersion()
const override {
return bytecodeVersion; }
1058 DialectReader withEncodingReader(EncodingReader &encReader)
const {
1059 return DialectReader(attrTypeReader, stringReader, resourceReader,
1060 dialectsMap, encReader, bytecodeVersion);
1063 Location getLoc()
const {
return reader.getLoc(); }
1071 static constexpr uint64_t maxAttrTypeDepth = 5;
1073 LogicalResult readAttribute(Attribute &
result)
override {
1075 if (
failed(reader.parseVarInt(index)))
1081 if (!attrTypeReader.isResolving()) {
1082 if (Attribute attr = attrTypeReader.resolveAttribute(index)) {
1089 if (depth > maxAttrTypeDepth) {
1090 if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
1094 attrTypeReader.addDeferredParsing(index,
1095 AttrTypeReader::EntryKind::Attribute);
1098 return attrTypeReader.readAttribute(index,
result, depth + 1);
1100 LogicalResult readOptionalAttribute(Attribute &
result)
override {
1101 return attrTypeReader.parseOptionalAttribute(reader,
result);
1103 LogicalResult readType(Type &
result)
override {
1105 if (
failed(reader.parseVarInt(index)))
1111 if (!attrTypeReader.isResolving()) {
1112 if (Type type = attrTypeReader.resolveType(index)) {
1119 if (depth > maxAttrTypeDepth) {
1120 if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
1124 attrTypeReader.addDeferredParsing(index, AttrTypeReader::EntryKind::Type);
1127 return attrTypeReader.readType(index,
result, depth + 1);
1131 AsmDialectResourceHandle handle;
1132 if (
failed(resourceReader.parseResourceHandle(reader, handle)))
1141 LogicalResult readVarInt(uint64_t &
result)
override {
1142 return reader.parseVarInt(
result);
1145 LogicalResult readSignedVarInt(int64_t &
result)
override {
1146 uint64_t unsignedResult;
1147 if (
failed(reader.parseSignedVarInt(unsignedResult)))
1149 result =
static_cast<int64_t
>(unsignedResult);
/// Reads an APInt whose bit width is already known to the caller. The
/// encoding depends on the width: a single byte, a signed varint, or a
/// word count followed by one signed varint per 64-bit word.
1153 FailureOr<APInt> readAPIntWithKnownWidth(
unsigned bitWidth)
override {
// Widths of at most 8 bits are stored as one raw byte.
1155 if (bitWidth <= 8) {
1157 if (
failed(reader.parseByte(value)))
1159 return APInt(bitWidth, value);
// Widths of at most 64 bits are stored as one signed varint.
1163 if (bitWidth <= 64) {
1165 if (
failed(reader.parseSignedVarInt(value)))
1167 return APInt(bitWidth, value);
// Wider values: a varint count of words, then each word as a signed varint.
// NOTE(review): `numActiveWords` comes straight from the input and sizes the
// `words` allocation — confirm an upper bound is enforced elsewhere.
1172 uint64_t numActiveWords;
1173 if (
failed(reader.parseVarInt(numActiveWords)))
1175 SmallVector<uint64_t, 4> words(numActiveWords);
1176 for (uint64_t i = 0; i < numActiveWords; ++i)
1177 if (
failed(reader.parseSignedVarInt(words[i])))
1179 return APInt(bitWidth, words);
1183 readAPFloatWithKnownSemantics(
const llvm::fltSemantics &semantics)
override {
1184 FailureOr<APInt> intVal =
1185 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
1188 return APFloat(semantics, *intVal);
1191 LogicalResult readString(StringRef &
result)
override {
1192 return stringReader.parseString(reader,
result);
1195 LogicalResult readBlob(ArrayRef<char> &
result)
override {
1197 ArrayRef<uint8_t> data;
1198 if (
failed(reader.parseVarInt(dataSize)) ||
1199 failed(reader.parseBytes(dataSize, data)))
1201 result = llvm::ArrayRef(
reinterpret_cast<const char *
>(data.data()),
1206 LogicalResult readBool(
bool &
result)
override {
1207 return reader.parseByte(
result);
1211 AttrTypeReader &attrTypeReader;
1212 const StringSectionReader &stringReader;
1213 const ResourceSectionReader &resourceReader;
1214 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1215 EncodingReader &reader;
1216 uint64_t &bytecodeVersion;
1221class PropertiesSectionReader {
/// Indexes the properties section: reads the entry count, keeps the rest of
/// the section as `propertiesBuffers`, and records in `offsetTable` the byte
/// offset at which each entry starts inside that buffer.
1224 LogicalResult
initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
// An empty section is handled up front, before any parsing.
1225 if (sectionData.empty())
1227 EncodingReader propReader(sectionData, fileLoc);
// The section begins with the number of property entries.
1229 if (
failed(propReader.parseVarInt(count)))
// Everything after the count is the concatenated entry data.
1232 if (
failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
// Walk the entry data once to build the index -> offset table.
1235 EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1236 offsetTable.reserve(count);
1237 for (
auto idx : llvm::seq<int64_t>(0, count)) {
// Offset of this entry = bytes consumed so far from `propertiesBuffers`.
1239 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
1240 ArrayRef<uint8_t> rawProperties;
// Each entry is a varint size followed by that many bytes; the bytes are
// only skipped over here.
1242 if (
failed(offsetsReader.parseVarInt(dataSize)) ||
1243 failed(offsetsReader.parseBytes(dataSize, rawProperties)))
// All bytes must belong to one of the `count` entries.
1246 if (!offsetsReader.empty())
1247 return offsetsReader.emitError()
1248 <<
"Broken properties section: didn't exhaust the offsets table";
/// Reads the properties of one operation. A varint read through
/// `dialectReader` selects an entry of `offsetTable`; the blob stored at that
/// offset in `propertiesBuffers` is then decoded by the operation's
/// `BytecodeOpInterface` into `opState`. Emits an error when the index or
/// the offset is out of range, or when the operation has no bytecode
/// interface.
1252 LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1253 OperationName *opName, OperationState &opState)
const {
1254 uint64_t propertiesIdx;
1255 if (
failed(dialectReader.readVarInt(propertiesIdx)))
// The index must name an entry recorded by `initialize`.
1257 if (propertiesIdx >= offsetTable.size())
1258 return dialectReader.emitError(
"Properties idx out-of-bound for ")
1260 size_t propertiesOffset = offsetTable[propertiesIdx];
// Validate the *offset* (not the index) against the buffer size: it is the
// offset that is passed to `drop_front` below and named in the message.
1261 if (propertiesOffset >= propertiesBuffers.size())
1262 return dialectReader.emitError(
"Properties offset out-of-bound for ")
1266 ArrayRef<char> rawProperties;
// First pass: position a reader at the entry and extract its blob.
1270 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
1274 dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
// Second pass: decode the blob contents with a reader scoped to just the
// blob bytes.
1278 EncodingReader reader(
1279 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1280 DialectReader propReader = dialectReader.withEncodingReader(reader);
1282 auto *iface = opName->
getInterface<BytecodeOpInterface>();
1284 return iface->readProperties(propReader, opState);
1286 return propReader.emitError(
1287 "has properties but missing BytecodeOpInterface for ")
1295 ArrayRef<uint8_t> propertiesBuffers;
1298 SmallVector<int64_t> offsetTable;
/// Indexes the attribute/type section. Reads the attribute and type counts
/// from the offset section, sizes the `attributes` and `types` tables, and
/// gives every entry its byte slice of `sectionData`, its owning dialect and
/// its custom-encoding flag. No attribute or type is decoded here.
1302LogicalResult AttrTypeReader::initialize(
1303 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1304 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1305 EncodingReader offsetReader(offsetSectionData, fileLoc);
// The offset section starts with the two entry counts.
1308 uint64_t numAttributes, numTypes;
1309 if (
failed(offsetReader.parseVarInt(numAttributes)) ||
1310 failed(offsetReader.parseVarInt(numTypes)))
1312 attributes.resize(numAttributes);
1313 types.resize(numTypes);
// Entries are laid out back to back in `sectionData`; `currentOffset` is the
// running start of the next entry and is shared by attributes and types.
1317 uint64_t currentOffset = 0;
1318 auto parseEntries = [&](
auto &&range) {
1319 size_t currentIndex = 0, endIndex = range.size();
// Callback run once per entry; `dialect` is the dialect owning the entry.
1322 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1323 auto &entry = range[currentIndex++];
// The size varint also carries the "has custom encoding" flag.
1326 if (
failed(offsetReader.parseVarIntWithFlag(entrySize,
1327 entry.hasCustomEncoding)))
// Bounds check in subtraction form: `currentOffset` never exceeds
// `sectionData.size()`, so this cannot wrap, unlike
// `currentOffset + entrySize` with an untrusted `entrySize`.
1331 if (entrySize > sectionData.size() - currentOffset) {
1332 return offsetReader.emitError(
1333 "Attribute or Type entry offset points past the end of section");
1336 entry.data = sectionData.slice(currentOffset, entrySize);
1337 entry.dialect = dialect;
1338 currentOffset += entrySize;
// Keep parsing until every entry of the range has been assigned.
1341 while (currentIndex != endIndex)
// Attributes come first, then types.
1348 if (
failed(parseEntries(attributes)) ||
failed(parseEntries(types)))
// The offset section must be fully consumed.
1352 if (!offsetReader.empty()) {
1353 return offsetReader.emitError(
1354 "unexpected trailing data in the Attribute/Type offset section");
/// Resolves entry `index` of `entries` and returns it. A direct read is
/// attempted first; if that read deferred other entries (recorded in
/// `deferredWorklist`), those are processed through an explicit worklist
/// until the requested entry can be read. `entryType` names the kind of
/// entry in diagnostics. The `resolving` flag is restored on every exit path.
1360template <
typename T>
1361T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
1362 uint64_t index, StringRef entryType,
// Remember the previous `resolving` state and restore it when leaving.
1364 bool oldResolving = resolving;
1366 llvm::scope_exit restoreResolving([&]() { resolving = oldResolving; });
1368 if (index >= entries.size()) {
1369 emitError(fileLoc) <<
"invalid " << entryType <<
" index: " << index;
// No deferred work may be pending when a resolution starts.
1375 assert(deferredWorklist.empty());
// Fast path: the entry can be read without deferring anything.
1377 if (succeeded(readEntry(entries, index,
result, entryType, depth))) {
1378 assert(deferredWorklist.empty());
// A failure that deferred nothing is handled separately from one that did.
1381 if (deferredWorklist.empty()) {
// Slow path: explicit worklist. `inWorklist` mirrors the queue contents so
// an entry is never queued twice.
1391 std::deque<std::pair<uint64_t, EntryKind>> worklist;
1392 llvm::DenseSet<std::pair<uint64_t, EntryKind>> inWorklist;
1394 EntryKind entryKind =
1395 std::is_same_v<T, Type> ? EntryKind::Type : EntryKind::Attribute;
// Compile-time guard with a real diagnostic message (condition, message).
1397 static_assert((std::is_same_v<T, Type> || std::is_same_v<T, Attribute>),
1398 "Only support resolving Attributes and Types");
// Queue an entry at the front unless it is already queued.
1400 auto addToWorklistFront = [&](std::pair<uint64_t, EntryKind> entry) {
1401 if (inWorklist.insert(entry).second)
1402 worklist.push_front(entry);
// Seed the queue: the requested entry goes last, and the entries deferred by
// the first attempt go in front of it.
1406 worklist.emplace_back(index, entryKind);
1407 inWorklist.insert({index, entryKind});
1408 for (
auto entry : llvm::reverse(deferredWorklist))
1409 addToWorklistFront(entry);
1411 while (!worklist.empty()) {
1412 auto [currentIndex, entryKind] = worklist.front();
1413 worklist.pop_front();
// Start each attempt with a clean deferred list so that new deferrals can
// be told apart from old ones.
1416 deferredWorklist.clear();
1418 if (entryKind == EntryKind::Type) {
1420 if (succeeded(readType(currentIndex,
result, depth))) {
1421 inWorklist.erase({currentIndex, entryKind});
1425 assert(entryKind == EntryKind::Attribute &&
"Unexpected entry kind");
1427 if (succeeded(readAttribute(currentIndex,
result, depth))) {
1428 inWorklist.erase({currentIndex, entryKind});
// A failed attempt that deferred nothing is treated differently from one
// that produced new deferred entries.
1433 if (deferredWorklist.empty()) {
// Re-queue the current entry at the back, behind its newly deferred
// dependencies, which are pushed to the front.
1439 worklist.emplace_back(currentIndex, entryKind);
1442 for (
auto entry : llvm::reverse(deferredWorklist))
1443 addToWorklistFront(entry);
1445 deferredWorklist.clear();
1447 return entries[index].entry;
/// Reads entry `index` of `entries` into `result`. An out-of-range index
/// yields an "invalid <entryType> index" error. The entry's bytes are decoded
/// with either the custom (dialect) encoding or the assembly-format fallback,
/// depending on the entry's `hasCustomEncoding` flag, and any bytes left
/// over afterwards are reported as an error.
1450template <
typename T>
1451LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
1452 uint64_t index, T &
result,
1453 StringRef entryType, uint64_t depth) {
1454 if (index >= entries.size())
1455 return emitError(fileLoc) <<
"invalid " << entryType <<
" index: " << index;
1458 Entry<T> &entry = entries[index];
// Decode from a reader scoped to exactly this entry's bytes.
1465 EncodingReader reader(entry.data, fileLoc);
1466 LogicalResult parseResult =
1467 entry.hasCustomEncoding
1468 ? parseCustomEntry(entry, reader, entryType, index, depth)
1469 : parseAsmEntry(entry.entry, reader, entryType);
// The decoder must consume the whole entry.
1473 if (!reader.empty())
1474 return reader.emitError(
"unexpected trailing bytes after " + entryType +
1481template <
typename T>
1482LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1483 EncodingReader &reader,
1484 StringRef entryType,
1485 uint64_t index, uint64_t depth) {
1486 DialectReader dialectReader(*
this, stringReader, resourceReader, dialectsMap,
1487 reader, bytecodeVersion, depth);
1491 if constexpr (std::is_same_v<T, Type>) {
1493 for (
const auto &callback :
1496 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1504 reader = EncodingReader(entry.data, reader.getLoc());
1508 for (
const auto &callback :
1511 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1519 reader = EncodingReader(entry.data, reader.getLoc());
1524 if (!entry.dialect->interface) {
1525 return reader.emitError(
"dialect '", entry.dialect->name,
1526 "' does not implement the bytecode interface");
1529 if constexpr (std::is_same_v<T, Type>)
1530 entry.entry = entry.dialect->interface->readType(dialectReader);
1532 entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1534 return success(!!entry.entry);
1537template <
typename T>
1538LogicalResult AttrTypeReader::parseAsmEntry(T &
result, EncodingReader &reader,
1539 StringRef entryType) {
1541 if (
failed(reader.parseNullTerminatedString(asmStr)))
1546 MLIRContext *context = fileLoc->
getContext();
1547 if constexpr (std::is_same_v<T, Type>)
1557 if (numRead != asmStr.size()) {
1558 return reader.emitError(
"trailing characters found after ", entryType,
1559 " assembly format: ", asmStr.drop_front(numRead));
1570 struct RegionReadState;
1571 using LazyLoadableOpsInfo =
1572 std::list<std::pair<Operation *, RegionReadState>>;
1573 using LazyLoadableOpsMap =
1578 llvm::MemoryBufferRef buffer,
1579 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1580 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1581 attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1585 forwardRefOpState(UnknownLoc::
get(config.getContext()),
1586 "builtin.unrealized_conversion_cast",
ValueRange(),
1587 NoneType::
get(config.getContext())),
1588 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1591 LogicalResult read(
Block *block,
1604 this->lazyOpsCallback = lazyOpsCallback;
1605 llvm::scope_exit resetlazyOpsCallback(
1606 [&] { this->lazyOpsCallback =
nullptr; });
1607 auto it = lazyLoadableOpsMap.find(op);
1608 assert(it != lazyLoadableOpsMap.end() &&
1609 "materialize called on non-materializable op");
1615 while (!lazyLoadableOpsMap.empty()) {
1616 if (failed(
materialize(lazyLoadableOpsMap.begin())))
1627 while (!lazyLoadableOps.empty()) {
1628 Operation *op = lazyLoadableOps.begin()->first;
1629 if (shouldMaterialize(op)) {
1630 if (failed(
materialize(lazyLoadableOpsMap.find(op))))
1636 lazyLoadableOps.pop_front();
1637 lazyLoadableOpsMap.erase(op);
/// Materializes the lazily-loadable operation that `it` refers to: opens a new
/// value scope, takes over the operation's saved region read state, removes
/// the operation from both lazy-loading containers, and parses regions until
/// the region stack is empty.
1643 LogicalResult
materialize(LazyLoadableOpsMap::iterator it) {
1644 assert(it != lazyLoadableOpsMap.end() &&
1645 "materialize called on non-materializable op");
1646 valueScopes.emplace_back();
1647 std::vector<RegionReadState> regionStack;
// Move the saved state out before erasing the list node that owns it.
1648 regionStack.push_back(std::move(it->getSecond()->second));
1649 lazyLoadableOps.erase(it->getSecond());
1650 lazyLoadableOpsMap.erase(it);
// Each call works on the current top of the stack.
1652 while (!regionStack.empty())
1653 if (failed(
parseRegions(regionStack, regionStack.back())))
/// Verifies that the bytecode buffer itself starts at an address that is a
/// multiple of the requested section `alignment`, and emits an error naming
/// the alignment and the buffer address when it does not.
1658 LogicalResult checkSectionAlignment(
// Mask test on the buffer's start address; assumes `alignment` is a power of
// two — TODO confirm against the lines not shown here.
1669 const bool isGloballyAligned =
1670 ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
1672 if (!isGloballyAligned)
1673 return emitError(
"expected section alignment ")
1674 << alignment <<
" but bytecode buffer 0x"
1675 << Twine::utohexstr((uint64_t)buffer.getBufferStart())
1676 <<
" is not aligned";
1685 LogicalResult parseVersion(EncodingReader &reader);
1690 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1695 FailureOr<OperationName> parseOpName(EncodingReader &reader,
1696 std::optional<bool> &wasRegistered);
1702 template <
typename T>
1704 return attrTypeReader.parseAttribute(reader,
result);
1707 return attrTypeReader.parseType(reader,
result);
1714 parseResourceSection(EncodingReader &reader,
1715 std::optional<ArrayRef<uint8_t>> resourceData,
1716 std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1723 struct RegionReadState {
/// Creates the read state for all regions of `op` by delegating to the
/// region-range constructor below.
1724 RegionReadState(Operation *op, EncodingReader *reader,
1725 bool isIsolatedFromAbove)
1726 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
/// Creates the read state for `regions`: the region cursor starts at the first
/// region, `reader` is stored as given (not owned through this pointer), and
/// the isolation flag is recorded.
1727 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1728 bool isIsolatedFromAbove)
1729 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1730 isIsolatedFromAbove(isIsolatedFromAbove) {}
1733 MutableArrayRef<Region>::iterator curRegion, endRegion;
1738 EncodingReader *reader;
1739 std::unique_ptr<EncodingReader> owningReader;
1742 unsigned numValues = 0;
1745 SmallVector<Block *> curBlocks;
1750 uint64_t numOpsRemaining = 0;
1753 bool isIsolatedFromAbove =
false;
1756 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData,
Block *block);
/// Parses regions using `regionStack` as the explicit work stack. Callers in
/// this file pass the stack together with its current top
/// (`regionStack.back()`) as `readState`.
1757 LogicalResult
parseRegions(std::vector<RegionReadState> &regionStack,
1758 RegionReadState &readState);
1759 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1760 RegionReadState &readState,
1761 bool &isIsolatedFromAbove);
1763 LogicalResult parseRegion(RegionReadState &readState);
1764 LogicalResult parseBlockHeader(EncodingReader &reader,
1765 RegionReadState &readState);
1766 LogicalResult parseBlockArguments(EncodingReader &reader,
Block *block);
1773 Value parseOperand(EncodingReader &reader);
1776 LogicalResult defineValues(EncodingReader &reader,
ValueRange values);
1779 Value createForwardRef();
// Use-list order parsed for a single value: the index list read from the
// bytecode plus a flag describing how that list is encoded.
1787 struct UseListOrderStorage {
1788 UseListOrderStorage(
bool isIndexPairEncoding,
1789 SmallVector<unsigned, 4> &&
indices)
1791 isIndexPairEncoding(isIndexPairEncoding) {};
// Indices as collected by parseUseListOrderForRange.
1794 SmallVector<unsigned, 4>
indices;
// When true, sortUseListOrder consumes `indices` two at a time as
// (position, new value) pairs and expands them into a shuffle with one entry
// per use.
1798 bool isIndexPairEncoding;
1806 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1807 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1808 uint64_t rangeSize);
1811 LogicalResult sortUseListOrder(Value value);
1815 LogicalResult processUseLists(Operation *topLevelOp);
// Reserve value slots for the region described by `readState`: remember the
// index at which its values start, then grow `values` by the region's value
// count.
1825 void push(RegionReadState &readState) {
1826 nextValueIDs.push_back(values.size());
1827 values.resize(values.size() + readState.numValues);
// Undo the matching push: drop the region's value slots from the end of
// `values` and discard its next-value-ID entry.
1832 void pop(RegionReadState &readState) {
1833 values.resize(values.size() - readState.numValues);
1834 nextValueIDs.pop_back();
1838 std::vector<Value> values;
1842 SmallVector<unsigned, 4> nextValueIDs;
1846 const ParserConfig &
config;
1857 LazyLoadableOpsInfo lazyLoadableOps;
1858 LazyLoadableOpsMap lazyLoadableOpsMap;
1859 llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1862 AttrTypeReader attrTypeReader;
1865 uint64_t version = 0;
1871 SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1872 llvm::StringMap<BytecodeDialect *> dialectsMap;
1873 SmallVector<BytecodeOperationName> opNames;
1876 ResourceSectionReader resourceReader;
1880 DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1883 StringSectionReader stringReader;
1886 PropertiesSectionReader propertiesReader;
1889 std::vector<ValueScope> valueScopes;
1896 Block forwardRefOps;
1900 Block openForwardRefOps;
1903 OperationState forwardRefOpState;
1906 llvm::MemoryBufferRef buffer;
1910 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1915 EncodingReader reader(buffer.getBuffer(), fileLoc);
1916 this->lazyOpsCallback = lazyOpsCallback;
1917 llvm::scope_exit resetlazyOpsCallback(
1918 [&] { this->lazyOpsCallback =
nullptr; });
1921 if (failed(reader.skipBytes(StringRef(
"ML\xefR").size())))
1924 if (failed(parseVersion(reader)) ||
1925 failed(reader.parseNullTerminatedString(producer)))
1931 diag.attachNote() <<
"in bytecode version " << version
1932 <<
" produced by: " << producer;
1936 const auto checkSectionAlignment = [&](
unsigned alignment) {
1937 return this->checkSectionAlignment(
1938 alignment, [&](
const auto &msg) {
return reader.emitError(msg); });
1942 std::optional<ArrayRef<uint8_t>>
1944 while (!reader.empty()) {
1949 reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
1953 if (sectionDatas[sectionID]) {
1954 return reader.emitError(
"duplicate top-level section: ",
1957 sectionDatas[sectionID] = sectionData;
1963 return reader.emitError(
"missing data for top-level section: ",
1969 if (failed(stringReader.initialize(
1975 failed(propertiesReader.initialize(
1984 if (failed(parseResourceSection(
1990 if (failed(attrTypeReader.initialize(
// Reads the bytecode version as a varint into `version` and rejects files
// whose version lies outside [minSupportedVersion, currentVersion].
1999LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
2000 if (failed(reader.parseVarInt(version)))
// Too old: no upgrade path is offered.
2006 if (version < minSupportedVersion) {
2007 return reader.emitError(
"bytecode version ", version,
2008 " is older than the current version of ",
2009 currentVersion,
", and upgrade is not supported");
// Too new for this reader.
2011 if (version > currentVersion) {
2012 return reader.emitError(
"bytecode version ", version,
2013 " is newer than the current version ",
// NOTE(review): the condition guarding this assignment is not visible in this
// listing; lazy loading is switched off on that path.
2018 lazyLoading =
false;
2026LogicalResult BytecodeDialect::load(
const DialectReader &reader,
2032 return reader.emitError(
"dialect '")
2034 <<
"' is unknown. If this is intended, please call "
2035 "allowUnregisteredDialects() on the MLIRContext, or use "
2036 "-allow-unregistered-dialect with the MLIR tool used.";
2038 dialect = loadedDialect;
2043 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
2044 if (!versionBuffer.empty()) {
2046 return reader.emitError(
"dialect '")
2048 <<
"' does not implement the bytecode interface, "
2049 "but found a version entry";
2050 EncodingReader encReader(versionBuffer, reader.getLoc());
2051 DialectReader versionReader = reader.withEncodingReader(encReader);
2052 loadedVersion = interface->readVersion(versionReader);
2060BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
2061 EncodingReader sectionReader(sectionData, fileLoc);
2064 uint64_t numDialects;
2065 if (
failed(sectionReader.parseVarInt(numDialects)))
2067 dialects.resize(numDialects);
2069 const auto checkSectionAlignment = [&](
unsigned alignment) {
2070 return this->checkSectionAlignment(alignment, [&](
const auto &msg) {
2071 return sectionReader.emitError(msg);
2076 for (uint64_t i = 0; i < numDialects; ++i) {
2077 dialects[i] = std::make_unique<BytecodeDialect>();
2081 if (
failed(stringReader.parseString(sectionReader, dialects[i]->name)))
2087 uint64_t dialectNameIdx;
2088 bool versionAvailable;
2089 if (
failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
2092 if (
failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
2093 dialects[i]->name)))
2095 if (versionAvailable) {
2097 if (
failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
2098 dialects[i]->versionBuffer)))
2101 emitError(fileLoc,
"expected dialect version section");
2105 dialectsMap[dialects[i]->name] = dialects[i].get();
2109 auto parseOpName = [&](BytecodeDialect *dialect) {
2111 std::optional<bool> wasRegistered;
2115 if (
failed(stringReader.parseString(sectionReader, opName)))
2118 bool wasRegisteredFlag;
2119 if (
failed(stringReader.parseStringWithFlag(sectionReader, opName,
2120 wasRegisteredFlag)))
2122 wasRegistered = wasRegisteredFlag;
2124 opNames.emplace_back(dialect, opName, wasRegistered);
2131 if (
failed(sectionReader.parseVarInt(numOps)))
2133 opNames.reserve(numOps);
2135 while (!sectionReader.empty())
2141FailureOr<OperationName>
2142BytecodeReader::Impl::parseOpName(EncodingReader &reader,
2143 std::optional<bool> &wasRegistered) {
2144 BytecodeOperationName *opName =
nullptr;
2147 wasRegistered = opName->wasRegistered;
2150 if (!opName->opName) {
2155 if (opName->name.empty()) {
2156 opName->opName.emplace(opName->dialect->name,
getContext());
2159 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2160 dialectsMap, reader, version);
2163 opName->opName.emplace((opName->dialect->name +
"." + opName->name).str(),
2167 return *opName->opName;
2174LogicalResult BytecodeReader::Impl::parseResourceSection(
2175 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
2176 std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
2178 if (resourceData.has_value() != resourceOffsetData.has_value()) {
2179 if (resourceOffsetData)
2180 return emitError(fileLoc,
"unexpected resource offset section when "
2181 "resource section is not present");
2184 "expected resource offset section when resource section is present");
2192 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2193 dialectsMap, reader, version);
2194 return resourceReader.initialize(fileLoc,
config, dialects, stringReader,
2195 *resourceData, *resourceOffsetData,
2196 dialectReader, bufferOwnerRef);
2203FailureOr<BytecodeReader::Impl::UseListMapT>
2204BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
2205 uint64_t numResults) {
2206 BytecodeReader::Impl::UseListMapT map;
2207 uint64_t numValuesToRead = 1;
2208 if (numResults > 1 &&
failed(reader.parseVarInt(numValuesToRead)))
2211 for (
size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
2212 uint64_t resultIdx = 0;
2213 if (numResults > 1 &&
failed(reader.parseVarInt(resultIdx)))
2217 bool indexPairEncoding;
2218 if (
failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
2221 SmallVector<unsigned, 4> useListOrders;
2222 for (
size_t idx = 0; idx < numValues; idx++) {
2224 if (
failed(reader.parseVarInt(index)))
2226 useListOrders.push_back(index);
2230 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
2231 std::move(useListOrders)));
2242LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
2247 bool hasIncomingOrder =
2252 bool alreadySorted =
true;
2256 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
2257 for (
auto item : llvm::drop_begin(llvm::enumerate(value.
getUses()))) {
2259 item.value(), operationIDs.at(item.value().getOwner()));
2260 alreadySorted &= prevID > currentID;
2261 currentOrder.push_back({item.index(), currentID});
2267 if (alreadySorted && !hasIncomingOrder)
2274 currentOrder.begin(), currentOrder.end(),
2275 [](
auto elem1,
auto elem2) { return elem1.second > elem2.second; });
2277 if (!hasIncomingOrder) {
2281 SmallVector<unsigned> shuffle(llvm::make_first_range(currentOrder));
2287 UseListOrderStorage customOrder =
2289 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
2295 if (customOrder.isIndexPairEncoding) {
2297 if (shuffle.size() & 1)
2300 SmallVector<unsigned, 4> newShuffle(numUses);
2302 std::iota(newShuffle.begin(), newShuffle.end(), idx);
2303 for (idx = 0; idx < shuffle.size(); idx += 2)
2304 newShuffle[shuffle[idx]] = shuffle[idx + 1];
2306 shuffle = std::move(newShuffle);
2313 uint64_t accumulator = 0;
2314 for (
const auto &elem : shuffle) {
2315 if (!set.insert(elem).second)
2317 accumulator += elem;
2319 if (numUses != shuffle.size() ||
2320 accumulator != (((numUses - 1) * numUses) >> 1))
2325 shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2326 currentOrder, [&](
auto item) {
return shuffle[item.first]; }));
2331LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2335 unsigned operationID = 0;
2337 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
2339 auto blockWalk = topLevelOp->
walk([
this](
Block *block) {
2341 if (
failed(sortUseListOrder(arg)))
2346 auto resultWalk = topLevelOp->
walk([
this](Operation *op) {
2353 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2361BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2363 EncodingReader reader(sectionData, fileLoc);
2366 std::vector<RegionReadState> regionStack;
2369 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2370 regionStack.emplace_back(*moduleOp, &reader,
true);
2371 regionStack.back().curBlocks.push_back(moduleOp->getBody());
2372 regionStack.back().curBlock = regionStack.back().curRegion->begin();
2373 if (
failed(parseBlockHeader(reader, regionStack.back())))
2375 valueScopes.emplace_back();
2376 valueScopes.back().push(regionStack.back());
2379 while (!regionStack.empty())
2382 if (!forwardRefOps.empty()) {
2383 return reader.emitError(
2384 "not all forward unresolved forward operand references");
2388 if (
failed(processUseLists(*moduleOp)))
2389 return reader.emitError(
2390 "parsed use-list orders were invalid and could not be applied");
2393 for (
const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2396 if (!byteCodeDialect->loadedVersion)
2398 if (byteCodeDialect->interface &&
2399 failed(byteCodeDialect->interface->upgradeFromVersion(
2400 *moduleOp, *byteCodeDialect->loadedVersion)))
2409 auto &parsedOps = moduleOp->getBody()->getOperations();
2411 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
// Reads the regions tracked by `readState`, one region and one block at a
// time. When an operation that has regions is read, a child state is either
// pushed onto `regionStack` or stored in `lazyLoadableOps` (lazy loading).
// Once every region of `readState` has been read, its entry is popped from
// `regionStack`.
2416BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
2417 RegionReadState &readState) {
// Adapter that reports section-alignment errors against the file location.
2418 const auto checkSectionAlignment = [&](
unsigned alignment) {
2419 return this->checkSectionAlignment(
2420 alignment, [&](
const auto &msg) {
return emitError(fileLoc, msg); });
2426 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
// Parse the region header: block/value counts and the first block header.
2432 if (
failed(parseRegion(readState)))
2436 if (readState.curRegion->empty())
2441 EncodingReader &reader = *readState.reader;
// Read the operations announced by the current block header.
2443 while (readState.numOpsRemaining--) {
2446 bool isIsolatedFromAbove =
false;
2447 FailureOr<Operation *> op =
2448 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
// The operation has regions: prepare a child state to read them.
2456 if ((*op)->getNumRegions()) {
2457 RegionReadState childState(*op, &reader, isIsolatedFromAbove);
2462 ArrayRef<uint8_t> sectionData;
2463 if (
failed(reader.parseSection(sectionID, checkSectionAlignment,
2467 return emitError(fileLoc,
"expected IR section for region");
// The regions live in their own section: give the child a dedicated reader
// over that section's bytes and keep it alive through `owningReader`.
2468 childState.owningReader =
2469 std::make_unique<EncodingReader>(sectionData, fileLoc);
2470 childState.reader = childState.owningReader.get();
// Defer the operation when lazy loading is on and either no callback is set
// or the callback returns false for it.
2474 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
2475 lazyLoadableOps.emplace_back(*op, std::move(childState));
2476 lazyLoadableOpsMap.try_emplace(*op,
2477 std::prev(lazyLoadableOps.end()));
2481 regionStack.push_back(std::move(childState));
// Operations isolated from above get a fresh value scope.
2484 if (isIsolatedFromAbove)
2485 valueScopes.emplace_back();
// Advance to the next block of the region and parse its header.
2491 if (++readState.curBlock == readState.curRegion->end())
2493 if (
failed(parseBlockHeader(reader, readState)))
// Region finished: reset the block cursor and release its value slots.
2498 readState.curBlock = {};
2499 valueScopes.back().pop(readState);
// All regions read: drop the value scope of an isolated region, then the
// entry on the region stack.
2504 if (readState.isIsolatedFromAbove) {
2505 assert(!valueScopes.empty() &&
"Expect a valueScope after reading region");
2506 valueScopes.pop_back();
2508 assert(!regionStack.empty() &&
"Expect a regionStack after reading region");
2509 regionStack.pop_back();
2513FailureOr<Operation *>
2514BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2515 RegionReadState &readState,
2516 bool &isIsolatedFromAbove) {
2518 std::optional<bool> wasRegistered;
2519 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2526 if (
failed(reader.parseByte(opMask)))
2536 OperationState opState(opLoc, *opName);
2540 DictionaryAttr dictAttr;
2551 "Unexpected missing `wasRegistered` opname flag at "
2552 "bytecode version ")
2553 << version <<
" with properties.";
2557 if (wasRegistered) {
2558 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2559 dialectsMap, reader, version);
2561 propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
2573 uint64_t numResults;
2574 if (
failed(reader.parseVarInt(numResults)))
2576 opState.
types.resize(numResults);
2577 for (
int i = 0, e = numResults; i < e; ++i)
2584 uint64_t numOperands;
2585 if (
failed(reader.parseVarInt(numOperands)))
2587 opState.
operands.resize(numOperands);
2588 for (
int i = 0, e = numOperands; i < e; ++i)
2589 if (!(opState.
operands[i] = parseOperand(reader)))
2596 if (
failed(reader.parseVarInt(numSuccs)))
2599 for (
int i = 0, e = numSuccs; i < e; ++i) {
2608 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2611 size_t numResults = opState.
types.size();
2612 auto parseResult = parseUseListOrderForRange(reader, numResults);
2615 resultIdxToUseListMap = std::move(*parseResult);
2620 uint64_t numRegions;
2621 if (
failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
2624 opState.
regions.reserve(numRegions);
2625 for (
int i = 0, e = numRegions; i < e; ++i)
2626 opState.
regions.push_back(std::make_unique<Region>());
2631 readState.curBlock->push_back(op);
2642 if (resultIdxToUseListMap.has_value()) {
2644 if (resultIdxToUseListMap->contains(idx)) {
2646 resultIdxToUseListMap->at(idx));
// Parses the header of the current region of `readState`: reads the block and
// value counts, creates the blocks, reserves the region's value slots, and
// parses the header of the first block.
2653LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2654 EncodingReader &reader = *readState.reader;
// Number of blocks in the region.
2658 if (
failed(reader.parseVarInt(numBlocks)))
// Number of values in the region.
2667 if (
failed(reader.parseVarInt(numValues)))
2669 readState.numValues = numValues;
// Create the blocks up front and append each one to the region.
2673 readState.curBlocks.clear();
2674 readState.curBlocks.reserve(numBlocks);
2675 for (uint64_t i = 0; i < numBlocks; ++i) {
2676 readState.curBlocks.push_back(
new Block());
2677 readState.curRegion->push_back(readState.curBlocks.back());
// Reserve `numValues` slots in the innermost value scope.
2681 valueScopes.back().push(readState);
// Start reading at the first block of the region.
2684 readState.curBlock = readState.curRegion->begin();
2685 return parseBlockHeader(reader, readState);
2689BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2690 RegionReadState &readState) {
2692 if (
failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2696 if (hasArgs &&
failed(parseBlockArguments(reader, &*readState.curBlock)))
2703 uint8_t hasUseListOrders = 0;
2704 if (hasArgs &&
failed(reader.parseByte(hasUseListOrders)))
2707 if (!hasUseListOrders)
2710 Block &blk = *readState.curBlock;
2711 auto argIdxToUseListMap =
2713 if (
failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
2717 if (argIdxToUseListMap->contains(idx))
2719 argIdxToUseListMap->at(idx));
2725LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2729 if (
failed(reader.parseVarInt(numArgs)))
2732 SmallVector<Type> argTypes;
2733 SmallVector<Location> argLocs;
2734 argTypes.reserve(numArgs);
2735 argLocs.reserve(numArgs);
2737 Location unknownLoc = UnknownLoc::get(
config.getContext());
2740 LocationAttr argLoc = unknownLoc;
2745 if (
failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2746 !(argType = attrTypeReader.resolveType(typeIdx)))
2756 argTypes.push_back(argType);
2757 argLocs.push_back(argLoc);
2767Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2768 std::vector<Value> &values = valueScopes.back().values;
2769 Value *value =
nullptr;
2775 *value = createForwardRef();
// Assigns the next consecutive value IDs of the innermost value scope to the
// given new values. A slot that already holds a value is treated as a forward
// reference: its uses are redirected to the new value.
2779LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2781 ValueScope &valueScope = valueScopes.back();
2782 std::vector<Value> &values = valueScope.values;
// `valueID` is a reference, so the loop below advances the scope's next ID.
2784 unsigned &valueID = valueScope.nextValueIDs.back();
2785 unsigned valueIDEnd = valueID + newValues.size();
// The new values must fit in the slots reserved for the scope.
2786 if (valueIDEnd > values.size()) {
2787 return reader.emitError(
2788 "value index range was outside of the expected range for "
2789 "the parent region, got [",
2790 valueID,
", ", valueIDEnd,
"), but the maximum index was ",
2795 for (
unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2796 Value newValue = newValues[i];
// Store the new value; a non-null previous occupant is a forward reference.
2799 if (Value oldValue = std::exchange(values[valueID], newValue)) {
2800 Operation *forwardRefOp = oldValue.getDefiningOp();
// The previous occupant must be defined by an op in `forwardRefOps`.
2805 assert(forwardRefOp && forwardRefOp->
getBlock() == &forwardRefOps &&
2806 "value index was already defined?");
// Redirect the uses, then move the placeholder op to `openForwardRefOps`,
// from which createForwardRef takes ops for reuse.
2808 oldValue.replaceAllUsesWith(newValue);
2809 forwardRefOp->
moveBefore(&openForwardRefOps, openForwardRefOps.end());
2815Value BytecodeReader::Impl::createForwardRef() {
2818 if (!openForwardRefOps.empty()) {
2819 Operation *op = &openForwardRefOps.back();
2820 op->
moveBefore(&forwardRefOps, forwardRefOps.end());
2824 return forwardRefOps.back().getResult(0);
2835 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2839 impl = std::make_unique<Impl>(sourceFileLoc,
config, lazyLoading, buffer,
2845 return impl->read(block, lazyOpsCallback);
2849 return impl->getNumOpsToMaterialize();
2853 return impl->isMaterializable(op);
2858 return impl->materialize(op, lazyOpsCallback);
2863 return impl->finalize(shouldMaterialize);
2867 return buffer.getBuffer().starts_with(
"ML\xefR");
2876 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2882 "input buffer is not an MLIR bytecode file");
2886 buffer, bufferOwnerRef);
2887 return reader.
read(block,
nullptr);
2898 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block,
config,
static LogicalResult parseDialectGrouping(EncodingReader &reader, MutableArrayRef< std::unique_ptr< BytecodeDialect > > dialects, function_ref< LogicalResult(BytecodeDialect *)> entryCallback)
Parse a single dialect group encoded in the byte stream.
static LogicalResult readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
Read the bytecode from the provided memory buffer reference.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version)
Returns true if the given top-level section ID is optional.
static LogicalResult parseResourceGroup(Location fileLoc, bool allowEmpty, EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef, function_ref< StringRef(StringRef)> remapKey={}, function_ref< LogicalResult(StringRef)> processKeyFn={})
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, uint64_t index, T &entry, StringRef entryStr)
Resolve an index into the given entry list.
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, T &entry, StringRef entryStr)
Parse and resolve an index into the given entry list.
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::string diag(const llvm::Value &value)
MutableArrayRef< char > getMutableData()
Return a mutable reference to the raw underlying data of this blob.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
bool isMutable() const
Return if the data of this blob is mutable.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
OpListType & getOperations()
BlockArgListType getArguments()
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Type > > > getTypeCallbacks() const
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Attribute > > > getAttributeCallbacks() const
Returns the callbacks available to the parser.
This class is used to read a bytecode buffer and translate it into MLIR.
LogicalResult materializeAll()
Materialize all operations.
LogicalResult read(Block *block, llvm::function_ref< bool(Operation *)> lazyOps)
Read the bytecode defined within buffer into the given block.
bool isMaterializable(Operation *op)
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, llvm::MemoryBufferRef buffer, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize)
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback)
Materialize the provided operation, invoke the lazyOpsCallback on every newly found lazy operation.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback=[](Operation *) { return false;})
Materialize the provided operation.
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize=[](Operation *) { return true;})
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
BytecodeReader(llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef={})
Create a bytecode reader for the given buffer.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
bool isMaterializable(Operation *op)
Return true if the provided op is materializable.
LogicalResult readTopLevel(Block *block, llvm::function_ref< bool(Operation *)> lazyOps=[](Operation *) { return false;})
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
MLIRContext is the top-level object for a collection of MLIR operations.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
bool isRegistered() const
Return if this operation is registered.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a configuration for the MLIR assembly parser.
BytecodeReaderConfig & getBytecodeReaderConfig() const
Returns the parsing configurations associated to the bytecode read.
BlockListType::iterator iterator
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
static AsmResourceBlob allocateWithAlign(ArrayRef< char > data, size_t align, AsmResourceBlob::DeleterFn deleter={}, bool dataIsMutable=false)
Create a new unmanaged resource directly referencing the provided data.
This class provides an abstraction over the different types of ranges over Values.
bool use_empty() const
Returns true if this value has no uses.
void shuffleUseList(ArrayRef< unsigned > indices)
Shuffle the use list order according to the provided indices.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
unsigned getNumUses() const
This method computes the number of uses of this Value.
bool hasOneUse() const
Returns true if this value has exactly one use.
use_iterator use_begin() const
static WalkResult advance()
static WalkResult interrupt()
@ kAttrType
This section contains the attributes and types referenced within an IR module.
@ kAttrTypeOffset
This section contains the offsets for the attribute and types within the AttrType section.
@ kIR
This section contains the list of operations serialized into the bytecode, and their nested regions/o...
@ kResource
This section contains the resources of the bytecode.
@ kResourceOffset
This section contains the offsets of resources within the Resource section.
@ kDialect
This section contains the dialects referenced within an IR module.
@ kString
This section contains strings referenced within the bytecode.
@ kDialectVersions
This section contains the versions of each dialect.
@ kProperties
This section contains the properties for the operations.
@ kNumSections
The total number of section types.
static uint64_t getUseID(OperandT &val, unsigned ownerID)
Get the unique ID of a value use.
@ kUseListOrdering
Use-list ordering started to be encoded in version 3.
@ kAlignmentByte
An arbitrary value used to fill alignment padding.
@ kVersion
The current bytecode version.
@ kLazyLoading
Support for lazy-loading of isolated region was added in version 2.
@ kDialectVersioning
Dialects versioning was added in version 1.
@ kElideUnknownBlockArgLocation
Avoid recording unknown locations on block arguments (compression) started in version 4.
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attribute...
@ kMinSupportedVersion
The minimum supported version of the bytecode.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
StringRef toString(AsmResourceEntryKind kind)
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
bool isBytecode(llvm::MemoryBufferRef buffer)
Returns true if the given buffer starts with the magic bytes that signal MLIR bytecode.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AsmResourceEntryKind
This enum represents the different kinds of resource values.
LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config)
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.