MLIR  22.0.0git
BytecodeReader.cpp
Go to the documentation of this file.
1 //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/Bytecode/Encoding.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/IR/Visitors.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Endian.h"
25 #include "llvm/Support/MemoryBufferRef.h"
26 #include "llvm/Support/SourceMgr.h"
27 
28 #include <cstddef>
29 #include <cstdint>
30 #include <list>
31 #include <memory>
32 #include <numeric>
33 #include <optional>
34 
35 #define DEBUG_TYPE "mlir-bytecode-reader"
36 
37 using namespace mlir;
38 
/// Stringify the given section ID for use in diagnostics. Each handled
/// section renders as "<Name> (<id>)"; any other value falls through to
/// "Unknown (<id>)", built from the numeric value of `sectionID`.
static std::string toString(bytecode::Section::ID sectionID) {
  switch (sectionID) {
  // NOTE(review): the `case` label preceding each `return` below is missing
  // from this listing; presumably there is one label per section ID, in the
  // numeric order 0-8 shown in each string — confirm against
  // mlir/Bytecode/Encoding.h.
    return "String (0)";
    return "Dialect (1)";
    return "AttrType (2)";
    return "AttrTypeOffset (3)";
    return "IR (4)";
    return "Resource (5)";
    return "ResourceOffset (6)";
    return "DialectVersions (7)";
    return "Properties (8)";
  default:
    return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
  }
}
64 
/// Returns true if the given top-level section ID is optional, i.e. a bytecode
/// file of the given `version` may omit it. Any ID outside the handled set
/// hits `llvm_unreachable`.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
  switch (sectionID) {
  // NOTE(review): the `case` labels for the three groups below are missing
  // from this listing. One group is always required, one is always optional,
  // and one (presumably the Properties section, given the constant name) is
  // optional only for versions older than
  // `bytecode::kNativePropertiesEncoding` — confirm against Encoding.h.
    return false;
    return true;
    return version < bytecode::kNativePropertiesEncoding;
  default:
    llvm_unreachable("unknown section ID");
  }
}
84 
85 //===----------------------------------------------------------------------===//
86 // EncodingReader
87 //===----------------------------------------------------------------------===//
88 
89 namespace {
90 class EncodingReader {
91 public:
92  explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
93  : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
94  explicit EncodingReader(StringRef contents, Location fileLoc)
95  : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
96  contents.size()},
97  fileLoc) {}
98 
99  /// Returns true if the entire section has been read.
100  bool empty() const { return dataIt == buffer.end(); }
101 
102  /// Returns the remaining size of the bytecode.
103  size_t size() const { return buffer.end() - dataIt; }
104 
105  /// Align the current reader position to the specified alignment.
106  LogicalResult alignTo(unsigned alignment) {
107  if (!llvm::isPowerOf2_32(alignment))
108  return emitError("expected alignment to be a power-of-two");
109 
110  auto isUnaligned = [&](const uint8_t *ptr) {
111  return ((uintptr_t)ptr & (alignment - 1)) != 0;
112  };
113 
114  // Shift the reader position to the next alignment boundary.
115  // Note: this assumes the pointer alignment matches the alignment of the
116  // data from the start of the buffer. In other words, this code is only
117  // valid if `dataIt` is offsetting into an already aligned buffer.
118  while (isUnaligned(dataIt)) {
119  uint8_t padding;
120  if (failed(parseByte(padding)))
121  return failure();
122  if (padding != bytecode::kAlignmentByte) {
123  return emitError("expected alignment byte (0xCB), but got: '0x" +
124  llvm::utohexstr(padding) + "'");
125  }
126  }
127 
128  // Ensure the data iterator is now aligned. This case is unlikely because we
129  // *just* went through the effort to align the data iterator.
130  if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
131  return emitError("expected data iterator aligned to ", alignment,
132  ", but got pointer: '0x" +
133  llvm::utohexstr((uintptr_t)dataIt) + "'");
134  }
135 
136  return success();
137  }
138 
139  /// Emit an error using the given arguments.
140  template <typename... Args>
141  InFlightDiagnostic emitError(Args &&...args) const {
142  return ::emitError(fileLoc).append(std::forward<Args>(args)...);
143  }
144  InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
145 
146  /// Parse a single byte from the stream.
147  template <typename T>
148  LogicalResult parseByte(T &value) {
149  if (empty())
150  return emitError("attempting to parse a byte at the end of the bytecode");
151  value = static_cast<T>(*dataIt++);
152  return success();
153  }
154  /// Parse a range of bytes of 'length' into the given result.
155  LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
156  if (length > size()) {
157  return emitError("attempting to parse ", length, " bytes when only ",
158  size(), " remain");
159  }
160  result = {dataIt, length};
161  dataIt += length;
162  return success();
163  }
164  /// Parse a range of bytes of 'length' into the given result, which can be
165  /// assumed to be large enough to hold `length`.
166  LogicalResult parseBytes(size_t length, uint8_t *result) {
167  if (length > size()) {
168  return emitError("attempting to parse ", length, " bytes when only ",
169  size(), " remain");
170  }
171  memcpy(result, dataIt, length);
172  dataIt += length;
173  return success();
174  }
175 
176  /// Parse an aligned blob of data, where the alignment was encoded alongside
177  /// the data.
178  LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
179  uint64_t &alignment) {
180  uint64_t dataSize;
181  if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
182  failed(alignTo(alignment)))
183  return failure();
184  return parseBytes(dataSize, data);
185  }
186 
187  /// Parse a variable length encoded integer from the byte stream. The first
188  /// encoded byte contains a prefix in the low bits indicating the encoded
189  /// length of the value. This length prefix is a bit sequence of '0's followed
190  /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
191  /// (not including the prefix byte). All remaining bits in the first byte,
192  /// along with all of the bits in additional bytes, provide the value of the
193  /// integer encoded in little-endian order.
194  LogicalResult parseVarInt(uint64_t &result) {
195  // Parse the first byte of the encoding, which contains the length prefix.
196  if (failed(parseByte(result)))
197  return failure();
198 
199  // Handle the overwhelmingly common case where the value is stored in a
200  // single byte. In this case, the first bit is the `1` marker bit.
201  if (LLVM_LIKELY(result & 1)) {
202  result >>= 1;
203  return success();
204  }
205 
206  // Handle the overwhelming uncommon case where the value required all 8
207  // bytes (i.e. a really really big number). In this case, the marker byte is
208  // all zeros: `00000000`.
209  if (LLVM_UNLIKELY(result == 0)) {
210  llvm::support::ulittle64_t resultLE;
211  if (failed(parseBytes(sizeof(resultLE),
212  reinterpret_cast<uint8_t *>(&resultLE))))
213  return failure();
214  result = resultLE;
215  return success();
216  }
217  return parseMultiByteVarInt(result);
218  }
219 
220  /// Parse a signed variable length encoded integer from the byte stream. A
221  /// signed varint is encoded as a normal varint with zigzag encoding applied,
222  /// i.e. the low bit of the value is used to indicate the sign.
223  LogicalResult parseSignedVarInt(uint64_t &result) {
224  if (failed(parseVarInt(result)))
225  return failure();
226  // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
227  result = (result >> 1) ^ (~(result & 1) + 1);
228  return success();
229  }
230 
231  /// Parse a variable length encoded integer whose low bit is used to encode an
232  /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
233  LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
234  if (failed(parseVarInt(result)))
235  return failure();
236  flag = result & 1;
237  result >>= 1;
238  return success();
239  }
240 
241  /// Skip the first `length` bytes within the reader.
242  LogicalResult skipBytes(size_t length) {
243  if (length > size()) {
244  return emitError("attempting to skip ", length, " bytes when only ",
245  size(), " remain");
246  }
247  dataIt += length;
248  return success();
249  }
250 
251  /// Parse a null-terminated string into `result` (without including the NUL
252  /// terminator).
253  LogicalResult parseNullTerminatedString(StringRef &result) {
254  const char *startIt = (const char *)dataIt;
255  const char *nulIt = (const char *)memchr(startIt, 0, size());
256  if (!nulIt)
257  return emitError(
258  "malformed null-terminated string, no null character found");
259 
260  result = StringRef(startIt, nulIt - startIt);
261  dataIt = (const uint8_t *)nulIt + 1;
262  return success();
263  }
264 
265  /// Validate that the alignment requested in the section is valid.
266  using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
267 
268  /// Parse a section header, placing the kind of section in `sectionID` and the
269  /// contents of the section in `sectionData`.
270  LogicalResult parseSection(bytecode::Section::ID &sectionID,
271  ValidateAlignmentFn alignmentValidator,
272  ArrayRef<uint8_t> &sectionData) {
273  uint8_t sectionIDAndHasAlignment;
274  uint64_t length;
275  if (failed(parseByte(sectionIDAndHasAlignment)) ||
276  failed(parseVarInt(length)))
277  return failure();
278 
279  // Extract the section ID and whether the section is aligned. The high bit
280  // of the ID is the alignment flag.
281  sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
282  0b01111111);
283  bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
284 
285  // Check that the section is actually valid before trying to process its
286  // data.
287  if (sectionID >= bytecode::Section::kNumSections)
288  return emitError("invalid section ID: ", unsigned(sectionID));
289 
290  // Process the section alignment if present.
291  if (hasAlignment) {
292  // Read the requested alignment from the bytecode parser.
293  uint64_t alignment;
294  if (failed(parseVarInt(alignment)))
295  return failure();
296 
297  // Check that the requested alignment must not exceed the alignment of
298  // the root buffer itself. Otherwise we cannot guarantee that pointers
299  // derived from this buffer will actually satisfy the requested alignment
300  // globally.
301  //
302  // Consider a bytecode buffer that is guaranteed to be 8k aligned, but not
303  // 16k aligned (e.g. absolute address 40960. If a section inside this
304  // buffer declares a 16k alignment requirement, two problems can arise:
305  //
306  // (a) If we "align forward" the current pointer to the next
307  // 16k boundary, the amount of padding we skip depends on the
308  // buffer's starting address. For example:
309  //
310  // buffer_start = 40960
311  // next 16k boundary = 49152
312  // bytes skipped = 49152 - 40960 = 8192
313  //
314  // This leaves behind variable padding that could be misinterpreted
315  // as part of the next section.
316  //
317  // (b) If we align relative to the buffer start, we may
318  // obtain addresses that are multiples of "buffer_start +
319  // section_alignment" rather than truly globally aligned
320  // addresses. For example:
321  //
322  // buffer_start = 40960 (5×8k, 8k aligned but not 16k)
323  // offset = 16384 (first multiple of 16k)
324  // section_ptr = 40960 + 16384 = 57344
325  //
326  // 57344 is 8k aligned but not 16k aligned.
327  // Any consumer expecting true 16k alignment would see this as a
328  // violation.
329  if (failed(alignmentValidator(alignment)))
330  return emitError("failed to align section ID: ", unsigned(sectionID));
331 
332  // Align the buffer.
333  if (failed(alignTo(alignment)))
334  return failure();
335  }
336 
337  // Parse the actual section data.
338  return parseBytes(static_cast<size_t>(length), sectionData);
339  }
340 
341  Location getLoc() const { return fileLoc; }
342 
343 private:
344  /// Parse a variable length encoded integer from the byte stream. This method
345  /// is a fallback when the number of bytes used to encode the value is greater
346  /// than 1, but less than the max (9). The provided `result` value can be
347  /// assumed to already contain the first byte of the value.
348  /// NOTE: This method is marked noinline to avoid pessimizing the common case
349  /// of single byte encoding.
350  LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
351  // Count the number of trailing zeros in the marker byte, this indicates the
352  // number of trailing bytes that are part of the value. We use `uint32_t`
353  // here because we only care about the first byte, and so that be actually
354  // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
355  // implementation).
356  uint32_t numBytes = llvm::countr_zero<uint32_t>(result);
357  assert(numBytes > 0 && numBytes <= 7 &&
358  "unexpected number of trailing zeros in varint encoding");
359 
360  // Parse in the remaining bytes of the value.
361  llvm::support::ulittle64_t resultLE(result);
362  if (failed(
363  parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1)))
364  return failure();
365 
366  // Shift out the low-order bits that were used to mark how the value was
367  // encoded.
368  result = resultLE >> (numBytes + 1);
369  return success();
370  }
371 
372  /// The bytecode buffer.
373  ArrayRef<uint8_t> buffer;
374 
375  /// The current iterator within the 'buffer'.
376  const uint8_t *dataIt;
377 
378  /// A location for the bytecode used to report errors.
379  Location fileLoc;
380 };
381 } // namespace
382 
383 /// Resolve an index into the given entry list. `entry` may either be a
384 /// reference, in which case it is assigned to the corresponding value in
385 /// `entries`, or a pointer, in which case it is assigned to the address of the
386 /// element in `entries`.
387 template <typename RangeT, typename T>
388 static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
389  uint64_t index, T &entry,
390  StringRef entryStr) {
391  if (index >= entries.size())
392  return reader.emitError("invalid ", entryStr, " index: ", index);
393 
394  // If the provided entry is a pointer, resolve to the address of the entry.
395  if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
396  entry = entries[index];
397  else
398  entry = &entries[index];
399  return success();
400 }
401 
402 /// Parse and resolve an index into the given entry list.
403 template <typename RangeT, typename T>
404 static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
405  T &entry, StringRef entryStr) {
406  uint64_t entryIdx;
407  if (failed(reader.parseVarInt(entryIdx)))
408  return failure();
409  return resolveEntry(reader, entries, entryIdx, entry, entryStr);
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // StringSectionReader
414 //===----------------------------------------------------------------------===//
415 
416 namespace {
417 /// This class is used to read references to the string section from the
418 /// bytecode.
419 class StringSectionReader {
420 public:
421  /// Initialize the string section reader with the given section data.
422  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
423 
424  /// Parse a shared string from the string section. The shared string is
425  /// encoded using an index to a corresponding string in the string section.
426  LogicalResult parseString(EncodingReader &reader, StringRef &result) const {
427  return parseEntry(reader, strings, result, "string");
428  }
429 
430  /// Parse a shared string from the string section. The shared string is
431  /// encoded using an index to a corresponding string in the string section.
432  /// This variant parses a flag compressed with the index.
433  LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
434  bool &flag) const {
435  uint64_t entryIdx;
436  if (failed(reader.parseVarIntWithFlag(entryIdx, flag)))
437  return failure();
438  return parseStringAtIndex(reader, entryIdx, result);
439  }
440 
441  /// Parse a shared string from the string section. The shared string is
442  /// encoded using an index to a corresponding string in the string section.
443  LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
444  StringRef &result) const {
445  return resolveEntry(reader, strings, index, result, "string");
446  }
447 
448 private:
449  /// The table of strings referenced within the bytecode file.
450  SmallVector<StringRef> strings;
451 };
452 } // namespace
453 
454 LogicalResult StringSectionReader::initialize(Location fileLoc,
455  ArrayRef<uint8_t> sectionData) {
456  EncodingReader stringReader(sectionData, fileLoc);
457 
458  // Parse the number of strings in the section.
459  uint64_t numStrings;
460  if (failed(stringReader.parseVarInt(numStrings)))
461  return failure();
462  strings.resize(numStrings);
463 
464  // Parse each of the strings. The sizes of the strings are encoded in reverse
465  // order, so that's the order we populate the table.
466  size_t stringDataEndOffset = sectionData.size();
467  for (StringRef &string : llvm::reverse(strings)) {
468  uint64_t stringSize;
469  if (failed(stringReader.parseVarInt(stringSize)))
470  return failure();
471  if (stringDataEndOffset < stringSize) {
472  return stringReader.emitError(
473  "string size exceeds the available data size");
474  }
475 
476  // Extract the string from the data, dropping the null character.
477  size_t stringOffset = stringDataEndOffset - stringSize;
478  string = StringRef(
479  reinterpret_cast<const char *>(sectionData.data() + stringOffset),
480  stringSize - 1);
481  stringDataEndOffset = stringOffset;
482  }
483 
484  // Check that the only remaining data was for the strings, i.e. the reader
485  // should be at the same offset as the first string.
486  if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
487  return stringReader.emitError("unexpected trailing data between the "
488  "offsets for strings and their data");
489  }
490  return success();
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // BytecodeDialect
495 //===----------------------------------------------------------------------===//
496 
497 namespace {
498 class DialectReader;
499 
500 /// This struct represents a dialect entry within the bytecode.
501 struct BytecodeDialect {
502  /// Load the dialect into the provided context if it hasn't been loaded yet.
503  /// Returns failure if the dialect couldn't be loaded *and* the provided
504  /// context does not allow unregistered dialects. The provided reader is used
505  /// for error emission if necessary.
506  LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
507 
508  /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
509  /// only be called after `load`.
510  Dialect *getLoadedDialect() const {
511  assert(dialect &&
512  "expected `load` to be invoked before `getLoadedDialect`");
513  return *dialect;
514  }
515 
516  /// The loaded dialect entry. This field is std::nullopt if we haven't
517  /// attempted to load, nullptr if we failed to load, otherwise the loaded
518  /// dialect.
519  std::optional<Dialect *> dialect;
520 
521  /// The bytecode interface of the dialect, or nullptr if the dialect does not
522  /// implement the bytecode interface. This field should only be checked if the
523  /// `dialect` field is not std::nullopt.
524  const BytecodeDialectInterface *interface = nullptr;
525 
526  /// The name of the dialect.
527  StringRef name;
528 
529  /// A buffer containing the encoding of the dialect version parsed.
530  ArrayRef<uint8_t> versionBuffer;
531 
532  /// Lazy loaded dialect version from the handle above.
533  std::unique_ptr<DialectVersion> loadedVersion;
534 };
535 
536 /// This struct represents an operation name entry within the bytecode.
537 struct BytecodeOperationName {
538  BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
539  std::optional<bool> wasRegistered)
540  : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
541 
542  /// The loaded operation name, or std::nullopt if it hasn't been processed
543  /// yet.
544  std::optional<OperationName> opName;
545 
546  /// The dialect that owns this operation name.
547  BytecodeDialect *dialect;
548 
549  /// The name of the operation, without the dialect prefix.
550  StringRef name;
551 
552  /// Whether this operation was registered when the bytecode was produced.
553  /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
554  std::optional<bool> wasRegistered;
555 };
556 } // namespace
557 
558 /// Parse a single dialect group encoded in the byte stream.
559 static LogicalResult parseDialectGrouping(
560  EncodingReader &reader,
561  MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
562  function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
563  // Parse the dialect and the number of entries in the group.
564  std::unique_ptr<BytecodeDialect> *dialect;
565  if (failed(parseEntry(reader, dialects, dialect, "dialect")))
566  return failure();
567  uint64_t numEntries;
568  if (failed(reader.parseVarInt(numEntries)))
569  return failure();
570 
571  for (uint64_t i = 0; i < numEntries; ++i)
572  if (failed(entryCallback(dialect->get())))
573  return failure();
574  return success();
575 }
576 
577 //===----------------------------------------------------------------------===//
578 // ResourceSectionReader
579 //===----------------------------------------------------------------------===//
580 
581 namespace {
582 /// This class is used to read the resource section from the bytecode.
583 class ResourceSectionReader {
584 public:
585  /// Initialize the resource section reader with the given section data.
586  LogicalResult
587  initialize(Location fileLoc, const ParserConfig &config,
588  MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
589  StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
590  ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
591  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
592 
593  /// Parse a dialect resource handle from the resource section.
594  LogicalResult parseResourceHandle(EncodingReader &reader,
595  AsmDialectResourceHandle &result) const {
596  return parseEntry(reader, dialectResources, result, "resource handle");
597  }
598 
599 private:
600  /// The table of dialect resources within the bytecode file.
601  SmallVector<AsmDialectResourceHandle> dialectResources;
602  llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
603 };
604 
605 class ParsedResourceEntry : public AsmParsedResourceEntry {
606 public:
607  ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
608  EncodingReader &reader, StringSectionReader &stringReader,
609  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
610  : key(key), kind(kind), reader(reader), stringReader(stringReader),
611  bufferOwnerRef(bufferOwnerRef) {}
612  ~ParsedResourceEntry() override = default;
613 
614  StringRef getKey() const final { return key; }
615 
616  InFlightDiagnostic emitError() const final { return reader.emitError(); }
617 
618  AsmResourceEntryKind getKind() const final { return kind; }
619 
620  FailureOr<bool> parseAsBool() const final {
622  return emitError() << "expected a bool resource entry, but found a "
623  << toString(kind) << " entry instead";
624 
625  bool value;
626  if (failed(reader.parseByte(value)))
627  return failure();
628  return value;
629  }
630  FailureOr<std::string> parseAsString() const final {
632  return emitError() << "expected a string resource entry, but found a "
633  << toString(kind) << " entry instead";
634 
635  StringRef string;
636  if (failed(stringReader.parseString(reader, string)))
637  return failure();
638  return string.str();
639  }
640 
641  FailureOr<AsmResourceBlob>
642  parseAsBlob(BlobAllocatorFn allocator) const final {
644  return emitError() << "expected a blob resource entry, but found a "
645  << toString(kind) << " entry instead";
646 
647  ArrayRef<uint8_t> data;
648  uint64_t alignment;
649  if (failed(reader.parseBlobAndAlignment(data, alignment)))
650  return failure();
651 
652  // If we have an extendable reference to the buffer owner, we don't need to
653  // allocate a new buffer for the data, and can use the data directly.
654  if (bufferOwnerRef) {
655  ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
656  data.size());
657 
658  // Allocate an unmanager buffer which captures a reference to the owner.
659  // For now we just mark this as immutable, but in the future we should
660  // explore marking this as mutable when desired.
662  charData, alignment,
663  [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
664  }
665 
666  // Allocate memory for the blob using the provided allocator and copy the
667  // data into it.
668  AsmResourceBlob blob = allocator(data.size(), alignment);
669  assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
670  blob.isMutable() &&
671  "blob allocator did not return a properly aligned address");
672  memcpy(blob.getMutableData().data(), data.data(), data.size());
673  return blob;
674  }
675 
676 private:
677  StringRef key;
679  EncodingReader &reader;
680  StringSectionReader &stringReader;
681  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
682 };
683 } // namespace
684 
685 template <typename T>
686 static LogicalResult
687 parseResourceGroup(Location fileLoc, bool allowEmpty,
688  EncodingReader &offsetReader, EncodingReader &resourceReader,
689  StringSectionReader &stringReader, T *handler,
690  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
691  function_ref<StringRef(StringRef)> remapKey = {},
692  function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
693  uint64_t numResources;
694  if (failed(offsetReader.parseVarInt(numResources)))
695  return failure();
696 
697  for (uint64_t i = 0; i < numResources; ++i) {
698  StringRef key;
700  uint64_t resourceOffset;
701  ArrayRef<uint8_t> data;
702  if (failed(stringReader.parseString(offsetReader, key)) ||
703  failed(offsetReader.parseVarInt(resourceOffset)) ||
704  failed(offsetReader.parseByte(kind)) ||
705  failed(resourceReader.parseBytes(resourceOffset, data)))
706  return failure();
707 
708  // Process the resource key.
709  if ((processKeyFn && failed(processKeyFn(key))))
710  return failure();
711 
712  // If the resource data is empty and we allow it, don't error out when
713  // parsing below, just skip it.
714  if (allowEmpty && data.empty())
715  continue;
716 
717  // Ignore the entry if we don't have a valid handler.
718  if (!handler)
719  continue;
720 
721  // Otherwise, parse the resource value.
722  EncodingReader entryReader(data, fileLoc);
723  key = remapKey(key);
724  ParsedResourceEntry entry(key, kind, entryReader, stringReader,
725  bufferOwnerRef);
726  if (failed(handler->parseResource(entry)))
727  return failure();
728  if (!entryReader.empty()) {
729  return entryReader.emitError(
730  "unexpected trailing bytes in resource entry '", key, "'");
731  }
732  }
733  return success();
734 }
735 
736 LogicalResult ResourceSectionReader::initialize(
737  Location fileLoc, const ParserConfig &config,
738  MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
739  StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
740  ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
741  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
742  EncodingReader resourceReader(sectionData, fileLoc);
743  EncodingReader offsetReader(offsetSectionData, fileLoc);
744 
745  // Read the number of external resource providers.
746  uint64_t numExternalResourceGroups;
747  if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
748  return failure();
749 
750  // Utility functor that dispatches to `parseResourceGroup`, but implicitly
751  // provides most of the arguments.
752  auto parseGroup = [&](auto *handler, bool allowEmpty = false,
753  function_ref<LogicalResult(StringRef)> keyFn = {}) {
754  auto resolveKey = [&](StringRef key) -> StringRef {
755  auto it = dialectResourceHandleRenamingMap.find(key);
756  if (it == dialectResourceHandleRenamingMap.end())
757  return key;
758  return it->second;
759  };
760 
761  return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
762  stringReader, handler, bufferOwnerRef, resolveKey,
763  keyFn);
764  };
765 
766  // Read the external resources from the bytecode.
767  for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
768  StringRef key;
769  if (failed(stringReader.parseString(offsetReader, key)))
770  return failure();
771 
772  // Get the handler for these resources.
773  // TODO: Should we require handling external resources in some scenarios?
774  AsmResourceParser *handler = config.getResourceParser(key);
775  if (!handler) {
776  emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
777  << "'";
778  }
779 
780  if (failed(parseGroup(handler)))
781  return failure();
782  }
783 
784  // Read the dialect resources from the bytecode.
785  MLIRContext *ctx = fileLoc->getContext();
786  while (!offsetReader.empty()) {
787  std::unique_ptr<BytecodeDialect> *dialect;
788  if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
789  failed((*dialect)->load(dialectReader, ctx)))
790  return failure();
791  Dialect *loadedDialect = (*dialect)->getLoadedDialect();
792  if (!loadedDialect) {
793  return resourceReader.emitError()
794  << "dialect '" << (*dialect)->name << "' is unknown";
795  }
796  const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
797  if (!handler) {
798  return resourceReader.emitError()
799  << "unexpected resources for dialect '" << (*dialect)->name << "'";
800  }
801 
802  // Ensure that each resource is declared before being processed.
803  auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
804  FailureOr<AsmDialectResourceHandle> handle =
805  handler->declareResource(key);
806  if (failed(handle)) {
807  return resourceReader.emitError()
808  << "unknown 'resource' key '" << key << "' for dialect '"
809  << (*dialect)->name << "'";
810  }
811  dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
812  dialectResources.push_back(*handle);
813  return success();
814  };
815 
816  // Parse the resources for this dialect. We allow empty resources because we
817  // just treat these as declarations.
818  if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
819  return failure();
820  }
821 
822  return success();
823 }
824 
825 //===----------------------------------------------------------------------===//
826 // Attribute/Type Reader
827 //===----------------------------------------------------------------------===//
828 
829 namespace {
830 /// This class provides support for reading attribute and type entries from the
831 /// bytecode. Attribute and Type entries are read lazily on demand, so we use
832 /// this reader to manage when to actually parse them from the bytecode.
833 class AttrTypeReader {
834  /// This class represents a single attribute or type entry.
835  template <typename T>
836  struct Entry {
837  /// The entry, or null if it hasn't been resolved yet.
838  T entry = {};
839  /// The parent dialect of this entry.
840  BytecodeDialect *dialect = nullptr;
841  /// A flag indicating if the entry was encoded using a custom encoding,
842  /// instead of using the textual assembly format.
843  bool hasCustomEncoding = false;
844  /// The raw data of this entry in the bytecode.
845  ArrayRef<uint8_t> data;
846  };
847  using AttrEntry = Entry<Attribute>;
848  using TypeEntry = Entry<Type>;
849 
850 public:
851  AttrTypeReader(const StringSectionReader &stringReader,
852  const ResourceSectionReader &resourceReader,
853  const llvm::StringMap<BytecodeDialect *> &dialectsMap,
854  uint64_t &bytecodeVersion, Location fileLoc,
855  const ParserConfig &config)
856  : stringReader(stringReader), resourceReader(resourceReader),
857  dialectsMap(dialectsMap), fileLoc(fileLoc),
858  bytecodeVersion(bytecodeVersion), parserConfig(config) {}
859 
860  /// Initialize the attribute and type information within the reader.
861  LogicalResult
862  initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
863  ArrayRef<uint8_t> sectionData,
864  ArrayRef<uint8_t> offsetSectionData);
865 
866  /// Resolve the attribute or type at the given index. Returns nullptr on
867  /// failure.
868  Attribute resolveAttribute(size_t index) {
869  return resolveEntry(attributes, index, "Attribute");
870  }
871  Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
872 
873  /// Parse a reference to an attribute or type using the given reader.
874  LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
875  uint64_t attrIdx;
876  if (failed(reader.parseVarInt(attrIdx)))
877  return failure();
878  result = resolveAttribute(attrIdx);
879  return success(!!result);
880  }
881  LogicalResult parseOptionalAttribute(EncodingReader &reader,
882  Attribute &result) {
883  uint64_t attrIdx;
884  bool flag;
885  if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
886  return failure();
887  if (!flag)
888  return success();
889  result = resolveAttribute(attrIdx);
890  return success(!!result);
891  }
892 
893  LogicalResult parseType(EncodingReader &reader, Type &result) {
894  uint64_t typeIdx;
895  if (failed(reader.parseVarInt(typeIdx)))
896  return failure();
897  result = resolveType(typeIdx);
898  return success(!!result);
899  }
900 
901  template <typename T>
902  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
903  Attribute baseResult;
904  if (failed(parseAttribute(reader, baseResult)))
905  return failure();
906  if ((result = dyn_cast<T>(baseResult)))
907  return success();
908  return reader.emitError("expected attribute of type: ",
909  llvm::getTypeName<T>(), ", but got: ", baseResult);
910  }
911 
912 private:
913  /// Resolve the given entry at `index`.
914  template <typename T>
915  T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
916  StringRef entryType);
917 
918  /// Parse an entry using the given reader that was encoded using the textual
919  /// assembly format.
920  template <typename T>
921  LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
922  StringRef entryType);
923 
924  /// Parse an entry using the given reader that was encoded using a custom
925  /// bytecode format.
926  template <typename T>
927  LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
928  StringRef entryType);
929 
930  /// The string section reader used to resolve string references when parsing
931  /// custom encoded attribute/type entries.
932  const StringSectionReader &stringReader;
933 
934  /// The resource section reader used to resolve resource references when
935  /// parsing custom encoded attribute/type entries.
936  const ResourceSectionReader &resourceReader;
937 
938  /// The map of the loaded dialects used to retrieve dialect information, such
939  /// as the dialect version.
940  const llvm::StringMap<BytecodeDialect *> &dialectsMap;
941 
942  /// The set of attribute and type entries.
943  SmallVector<AttrEntry> attributes;
945 
946  /// A location used for error emission.
947  Location fileLoc;
948 
949  /// Current bytecode version being used.
950  uint64_t &bytecodeVersion;
951 
952  /// Reference to the parser configuration.
953  const ParserConfig &parserConfig;
954 };
955 
956 class DialectReader : public DialectBytecodeReader {
957 public:
958  DialectReader(AttrTypeReader &attrTypeReader,
959  const StringSectionReader &stringReader,
960  const ResourceSectionReader &resourceReader,
961  const llvm::StringMap<BytecodeDialect *> &dialectsMap,
962  EncodingReader &reader, uint64_t &bytecodeVersion)
963  : attrTypeReader(attrTypeReader), stringReader(stringReader),
964  resourceReader(resourceReader), dialectsMap(dialectsMap),
965  reader(reader), bytecodeVersion(bytecodeVersion) {}
966 
967  InFlightDiagnostic emitError(const Twine &msg) const override {
968  return reader.emitError(msg);
969  }
970 
971  FailureOr<const DialectVersion *>
972  getDialectVersion(StringRef dialectName) const override {
973  // First check if the dialect is available in the map.
974  auto dialectEntry = dialectsMap.find(dialectName);
975  if (dialectEntry == dialectsMap.end())
976  return failure();
977  // If the dialect was found, try to load it. This will trigger reading the
978  // bytecode version from the version buffer if it wasn't already processed.
979  // Return failure if either of those two actions could not be completed.
980  if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
981  dialectEntry->getValue()->loadedVersion == nullptr)
982  return failure();
983  return dialectEntry->getValue()->loadedVersion.get();
984  }
985 
986  MLIRContext *getContext() const override { return getLoc().getContext(); }
987 
988  uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
989 
990  DialectReader withEncodingReader(EncodingReader &encReader) const {
991  return DialectReader(attrTypeReader, stringReader, resourceReader,
992  dialectsMap, encReader, bytecodeVersion);
993  }
994 
995  Location getLoc() const { return reader.getLoc(); }
996 
997  //===--------------------------------------------------------------------===//
998  // IR
999  //===--------------------------------------------------------------------===//
1000 
1001  LogicalResult readAttribute(Attribute &result) override {
1002  return attrTypeReader.parseAttribute(reader, result);
1003  }
1004  LogicalResult readOptionalAttribute(Attribute &result) override {
1005  return attrTypeReader.parseOptionalAttribute(reader, result);
1006  }
1007  LogicalResult readType(Type &result) override {
1008  return attrTypeReader.parseType(reader, result);
1009  }
1010 
1011  FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
1012  AsmDialectResourceHandle handle;
1013  if (failed(resourceReader.parseResourceHandle(reader, handle)))
1014  return failure();
1015  return handle;
1016  }
1017 
1018  //===--------------------------------------------------------------------===//
1019  // Primitives
1020  //===--------------------------------------------------------------------===//
1021 
1022  LogicalResult readVarInt(uint64_t &result) override {
1023  return reader.parseVarInt(result);
1024  }
1025 
1026  LogicalResult readSignedVarInt(int64_t &result) override {
1027  uint64_t unsignedResult;
1028  if (failed(reader.parseSignedVarInt(unsignedResult)))
1029  return failure();
1030  result = static_cast<int64_t>(unsignedResult);
1031  return success();
1032  }
1033 
1034  FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
1035  // Small values are encoded using a single byte.
1036  if (bitWidth <= 8) {
1037  uint8_t value;
1038  if (failed(reader.parseByte(value)))
1039  return failure();
1040  return APInt(bitWidth, value);
1041  }
1042 
1043  // Large values up to 64 bits are encoded using a single varint.
1044  if (bitWidth <= 64) {
1045  uint64_t value;
1046  if (failed(reader.parseSignedVarInt(value)))
1047  return failure();
1048  return APInt(bitWidth, value);
1049  }
1050 
1051  // Otherwise, for really big values we encode the array of active words in
1052  // the value.
1053  uint64_t numActiveWords;
1054  if (failed(reader.parseVarInt(numActiveWords)))
1055  return failure();
1056  SmallVector<uint64_t, 4> words(numActiveWords);
1057  for (uint64_t i = 0; i < numActiveWords; ++i)
1058  if (failed(reader.parseSignedVarInt(words[i])))
1059  return failure();
1060  return APInt(bitWidth, words);
1061  }
1062 
1063  FailureOr<APFloat>
1064  readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
1065  FailureOr<APInt> intVal =
1066  readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
1067  if (failed(intVal))
1068  return failure();
1069  return APFloat(semantics, *intVal);
1070  }
1071 
1072  LogicalResult readString(StringRef &result) override {
1073  return stringReader.parseString(reader, result);
1074  }
1075 
1076  LogicalResult readBlob(ArrayRef<char> &result) override {
1077  uint64_t dataSize;
1078  ArrayRef<uint8_t> data;
1079  if (failed(reader.parseVarInt(dataSize)) ||
1080  failed(reader.parseBytes(dataSize, data)))
1081  return failure();
1082  result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()),
1083  data.size());
1084  return success();
1085  }
1086 
1087  LogicalResult readBool(bool &result) override {
1088  return reader.parseByte(result);
1089  }
1090 
1091 private:
1092  AttrTypeReader &attrTypeReader;
1093  const StringSectionReader &stringReader;
1094  const ResourceSectionReader &resourceReader;
1095  const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1096  EncodingReader &reader;
1097  uint64_t &bytecodeVersion;
1098 };
1099 
1100 /// Wraps the properties section and handles reading properties out of it.
1101 class PropertiesSectionReader {
1102 public:
1103  /// Initialize the properties section reader with the given section data.
1104  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1105  if (sectionData.empty())
1106  return success();
1107  EncodingReader propReader(sectionData, fileLoc);
1108  uint64_t count;
1109  if (failed(propReader.parseVarInt(count)))
1110  return failure();
1111  // Parse the raw properties buffer.
1112  if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
1113  return failure();
1114 
1115  EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1116  offsetTable.reserve(count);
1117  for (auto idx : llvm::seq<int64_t>(0, count)) {
1118  (void)idx;
1119  offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
1120  ArrayRef<uint8_t> rawProperties;
1121  uint64_t dataSize;
1122  if (failed(offsetsReader.parseVarInt(dataSize)) ||
1123  failed(offsetsReader.parseBytes(dataSize, rawProperties)))
1124  return failure();
1125  }
1126  if (!offsetsReader.empty())
1127  return offsetsReader.emitError()
1128  << "Broken properties section: didn't exhaust the offsets table";
1129  return success();
1130  }
1131 
1132  LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1133  OperationName *opName, OperationState &opState) const {
1134  uint64_t propertiesIdx;
1135  if (failed(dialectReader.readVarInt(propertiesIdx)))
1136  return failure();
1137  if (propertiesIdx >= offsetTable.size())
1138  return dialectReader.emitError("Properties idx out-of-bound for ")
1139  << opName->getStringRef();
1140  size_t propertiesOffset = offsetTable[propertiesIdx];
1141  if (propertiesIdx >= propertiesBuffers.size())
1142  return dialectReader.emitError("Properties offset out-of-bound for ")
1143  << opName->getStringRef();
1144 
1145  // Acquire the sub-buffer that represent the requested properties.
1146  ArrayRef<char> rawProperties;
1147  {
1148  // "Seek" to the requested offset by getting a new reader with the right
1149  // sub-buffer.
1150  EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
1151  fileLoc);
1152  // Properties are stored as a sequence of {size + raw_data}.
1153  if (failed(
1154  dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
1155  return failure();
1156  }
1157  // Setup a new reader to read from the `rawProperties` sub-buffer.
1158  EncodingReader reader(
1159  StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1160  DialectReader propReader = dialectReader.withEncodingReader(reader);
1161 
1162  auto *iface = opName->getInterface<BytecodeOpInterface>();
1163  if (iface)
1164  return iface->readProperties(propReader, opState);
1165  if (opName->isRegistered())
1166  return propReader.emitError(
1167  "has properties but missing BytecodeOpInterface for ")
1168  << opName->getStringRef();
1169  // Unregistered op are storing properties as an attribute.
1170  return propReader.readAttribute(opState.propertiesAttr);
1171  }
1172 
1173 private:
1174  /// The properties buffer referenced within the bytecode file.
1175  ArrayRef<uint8_t> propertiesBuffers;
1176 
1177  /// Table of offset in the buffer above.
1178  SmallVector<int64_t> offsetTable;
1179 };
1180 } // namespace
1181 
1182 LogicalResult AttrTypeReader::initialize(
1183  MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1184  ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1185  EncodingReader offsetReader(offsetSectionData, fileLoc);
1186 
1187  // Parse the number of attribute and type entries.
1188  uint64_t numAttributes, numTypes;
1189  if (failed(offsetReader.parseVarInt(numAttributes)) ||
1190  failed(offsetReader.parseVarInt(numTypes)))
1191  return failure();
1192  attributes.resize(numAttributes);
1193  types.resize(numTypes);
1194 
1195  // A functor used to accumulate the offsets for the entries in the given
1196  // range.
1197  uint64_t currentOffset = 0;
1198  auto parseEntries = [&](auto &&range) {
1199  size_t currentIndex = 0, endIndex = range.size();
1200 
1201  // Parse an individual entry.
1202  auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1203  auto &entry = range[currentIndex++];
1204 
1205  uint64_t entrySize;
1206  if (failed(offsetReader.parseVarIntWithFlag(entrySize,
1207  entry.hasCustomEncoding)))
1208  return failure();
1209 
1210  // Verify that the offset is actually valid.
1211  if (currentOffset + entrySize > sectionData.size()) {
1212  return offsetReader.emitError(
1213  "Attribute or Type entry offset points past the end of section");
1214  }
1215 
1216  entry.data = sectionData.slice(currentOffset, entrySize);
1217  entry.dialect = dialect;
1218  currentOffset += entrySize;
1219  return success();
1220  };
1221  while (currentIndex != endIndex)
1222  if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
1223  return failure();
1224  return success();
1225  };
1226 
1227  // Process each of the attributes, and then the types.
1228  if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
1229  return failure();
1230 
1231  // Ensure that we read everything from the section.
1232  if (!offsetReader.empty()) {
1233  return offsetReader.emitError(
1234  "unexpected trailing data in the Attribute/Type offset section");
1235  }
1236 
1237  return success();
1238 }
1239 
1240 template <typename T>
1241 T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1242  StringRef entryType) {
1243  if (index >= entries.size()) {
1244  emitError(fileLoc) << "invalid " << entryType << " index: " << index;
1245  return {};
1246  }
1247 
1248  // If the entry has already been resolved, there is nothing left to do.
1249  Entry<T> &entry = entries[index];
1250  if (entry.entry)
1251  return entry.entry;
1252 
1253  // Parse the entry.
1254  EncodingReader reader(entry.data, fileLoc);
1255 
1256  // Parse based on how the entry was encoded.
1257  if (entry.hasCustomEncoding) {
1258  if (failed(parseCustomEntry(entry, reader, entryType)))
1259  return T();
1260  } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1261  return T();
1262  }
1263 
1264  if (!reader.empty()) {
1265  reader.emitError("unexpected trailing bytes after " + entryType + " entry");
1266  return T();
1267  }
1268  return entry.entry;
1269 }
1270 
1271 template <typename T>
1272 LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1273  StringRef entryType) {
1274  StringRef asmStr;
1275  if (failed(reader.parseNullTerminatedString(asmStr)))
1276  return failure();
1277 
1278  // Invoke the MLIR assembly parser to parse the entry text.
1279  size_t numRead = 0;
1280  MLIRContext *context = fileLoc->getContext();
1281  if constexpr (std::is_same_v<T, Type>)
1282  result =
1283  ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1284  else
1285  result = ::parseAttribute(asmStr, context, Type(), &numRead,
1286  /*isKnownNullTerminated=*/true);
1287  if (!result)
1288  return failure();
1289 
1290  // Ensure there weren't dangling characters after the entry.
1291  if (numRead != asmStr.size()) {
1292  return reader.emitError("trailing characters found after ", entryType,
1293  " assembly format: ", asmStr.drop_front(numRead));
1294  }
1295  return success();
1296 }
1297 
1298 template <typename T>
1299 LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1300  EncodingReader &reader,
1301  StringRef entryType) {
1302  DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1303  reader, bytecodeVersion);
1304  if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
1305  return failure();
1306 
1307  if constexpr (std::is_same_v<T, Type>) {
1308  // Try parsing with callbacks first if available.
1309  for (const auto &callback :
1310  parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
1311  if (failed(
1312  callback->read(dialectReader, entry.dialect->name, entry.entry)))
1313  return failure();
1314  // Early return if parsing was successful.
1315  if (!!entry.entry)
1316  return success();
1317 
1318  // Reset the reader if we failed to parse, so we can fall through the
1319  // other parsing functions.
1320  reader = EncodingReader(entry.data, reader.getLoc());
1321  }
1322  } else {
1323  // Try parsing with callbacks first if available.
1324  for (const auto &callback :
1325  parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
1326  if (failed(
1327  callback->read(dialectReader, entry.dialect->name, entry.entry)))
1328  return failure();
1329  // Early return if parsing was successful.
1330  if (!!entry.entry)
1331  return success();
1332 
1333  // Reset the reader if we failed to parse, so we can fall through the
1334  // other parsing functions.
1335  reader = EncodingReader(entry.data, reader.getLoc());
1336  }
1337  }
1338 
1339  // Ensure that the dialect implements the bytecode interface.
1340  if (!entry.dialect->interface) {
1341  return reader.emitError("dialect '", entry.dialect->name,
1342  "' does not implement the bytecode interface");
1343  }
1344 
1345  if constexpr (std::is_same_v<T, Type>)
1346  entry.entry = entry.dialect->interface->readType(dialectReader);
1347  else
1348  entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1349 
1350  return success(!!entry.entry);
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // Bytecode Reader
1355 //===----------------------------------------------------------------------===//
1356 
1357 /// This class is used to read a bytecode buffer and translate it into MLIR.
1359  struct RegionReadState;
1360  using LazyLoadableOpsInfo =
1361  std::list<std::pair<Operation *, RegionReadState>>;
1362  using LazyLoadableOpsMap =
1364 
1365 public:
1366  Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
1367  llvm::MemoryBufferRef buffer,
1368  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1369  : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1370  attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1371  fileLoc, config),
1372  // Use the builtin unrealized conversion cast operation to represent
1373  // forward references to values that aren't yet defined.
1374  forwardRefOpState(UnknownLoc::get(config.getContext()),
1375  "builtin.unrealized_conversion_cast", ValueRange(),
1376  NoneType::get(config.getContext())),
1377  buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1378 
1379  /// Read the bytecode defined within `buffer` into the given block.
1380  LogicalResult read(Block *block,
1381  llvm::function_ref<bool(Operation *)> lazyOps);
1382 
1383  /// Return the number of ops that haven't been materialized yet.
1384  int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
1385 
1386  bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
1387 
1388  /// Materialize the provided operation, invoke the lazyOpsCallback on every
1389  /// newly found lazy operation.
1390  LogicalResult
1392  llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1393  this->lazyOpsCallback = lazyOpsCallback;
1394  auto resetlazyOpsCallback =
1395  llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1396  auto it = lazyLoadableOpsMap.find(op);
1397  assert(it != lazyLoadableOpsMap.end() &&
1398  "materialize called on non-materializable op");
1399  return materialize(it);
1400  }
1401 
1402  /// Materialize all operations.
1403  LogicalResult materializeAll() {
1404  while (!lazyLoadableOpsMap.empty()) {
1405  if (failed(materialize(lazyLoadableOpsMap.begin())))
1406  return failure();
1407  }
1408  return success();
1409  }
1410 
1411  /// Finalize the lazy-loading by calling back with every op that hasn't been
1412  /// materialized to let the client decide if the op should be deleted or
1413  /// materialized. The op is materialized if the callback returns true, deleted
1414  /// otherwise.
1415  LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
1416  while (!lazyLoadableOps.empty()) {
1417  Operation *op = lazyLoadableOps.begin()->first;
1418  if (shouldMaterialize(op)) {
1419  if (failed(materialize(lazyLoadableOpsMap.find(op))))
1420  return failure();
1421  continue;
1422  }
1423  op->dropAllReferences();
1424  op->erase();
1425  lazyLoadableOps.pop_front();
1426  lazyLoadableOpsMap.erase(op);
1427  }
1428  return success();
1429  }
1430 
1431 private:
1432  LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
1433  assert(it != lazyLoadableOpsMap.end() &&
1434  "materialize called on non-materializable op");
1435  valueScopes.emplace_back();
1436  std::vector<RegionReadState> regionStack;
1437  regionStack.push_back(std::move(it->getSecond()->second));
1438  lazyLoadableOps.erase(it->getSecond());
1439  lazyLoadableOpsMap.erase(it);
1440 
1441  while (!regionStack.empty())
1442  if (failed(parseRegions(regionStack, regionStack.back())))
1443  return failure();
1444  return success();
1445  }
1446 
1447  LogicalResult checkSectionAlignment(
1448  unsigned alignment,
1449  function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
1450  // Check that the bytecode buffer meets the requested section alignment.
1451  //
1452  // If it does not, the virtual address of the item in the section will
1453  // not be aligned to the requested alignment.
1454  //
1455  // The typical case where this is necessary is the resource blob
1456  // optimization in `parseAsBlob` where we reference the weights from the
1457  // provided buffer instead of copying them to a new allocation.
1458  const bool isGloballyAligned =
1459  ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
1460 
1461  if (!isGloballyAligned)
1462  return emitError("expected section alignment ")
1463  << alignment << " but bytecode buffer 0x"
1464  << Twine::utohexstr((uint64_t)buffer.getBufferStart())
1465  << " is not aligned";
1466 
1467  return success();
1468  };
1469 
1470  /// Return the context for this config.
1471  MLIRContext *getContext() const { return config.getContext(); }
1472 
1473  /// Parse the bytecode version.
1474  LogicalResult parseVersion(EncodingReader &reader);
1475 
1476  //===--------------------------------------------------------------------===//
1477  // Dialect Section
1478 
1479  LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1480 
1481  /// Parse an operation name reference using the given reader, and set the
1482  /// `wasRegistered` flag that indicates if the bytecode was produced by a
1483  /// context where opName was registered.
1484  FailureOr<OperationName> parseOpName(EncodingReader &reader,
1485  std::optional<bool> &wasRegistered);
1486 
1487  //===--------------------------------------------------------------------===//
1488  // Attribute/Type Section
1489 
1490  /// Parse an attribute or type using the given reader.
1491  template <typename T>
1492  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
1493  return attrTypeReader.parseAttribute(reader, result);
1494  }
1495  LogicalResult parseType(EncodingReader &reader, Type &result) {
1496  return attrTypeReader.parseType(reader, result);
1497  }
1498 
1499  //===--------------------------------------------------------------------===//
1500  // Resource Section
1501 
1502  LogicalResult
1503  parseResourceSection(EncodingReader &reader,
1504  std::optional<ArrayRef<uint8_t>> resourceData,
1505  std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1506 
1507  //===--------------------------------------------------------------------===//
1508  // IR Section
1509 
1510  /// This struct represents the current read state of a range of regions. This
1511  /// struct is used to enable iterative parsing of regions.
1512  struct RegionReadState {
1513  RegionReadState(Operation *op, EncodingReader *reader,
1514  bool isIsolatedFromAbove)
1515  : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1516  RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1517  bool isIsolatedFromAbove)
1518  : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1519  isIsolatedFromAbove(isIsolatedFromAbove) {}
1520 
1521  /// The current regions being read.
1522  MutableArrayRef<Region>::iterator curRegion, endRegion;
1523  /// This is the reader to use for this region, this pointer is pointing to
1524  /// the parent region reader unless the current region is IsolatedFromAbove,
1525  /// in which case the pointer is pointing to the `owningReader` which is a
1526  /// section dedicated to the current region.
1527  EncodingReader *reader;
1528  std::unique_ptr<EncodingReader> owningReader;
1529 
1530  /// The number of values defined immediately within this region.
1531  unsigned numValues = 0;
1532 
1533  /// The current blocks of the region being read.
1534  SmallVector<Block *> curBlocks;
1535  Region::iterator curBlock = {};
1536 
1537  /// The number of operations remaining to be read from the current block
1538  /// being read.
1539  uint64_t numOpsRemaining = 0;
1540 
1541  /// A flag indicating if the regions being read are isolated from above.
1542  bool isIsolatedFromAbove = false;
1543  };
1544 
1545  LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
1546  LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
1547  RegionReadState &readState);
1548  FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1549  RegionReadState &readState,
1550  bool &isIsolatedFromAbove);
1551 
1552  LogicalResult parseRegion(RegionReadState &readState);
1553  LogicalResult parseBlockHeader(EncodingReader &reader,
1554  RegionReadState &readState);
1555  LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
1556 
1557  //===--------------------------------------------------------------------===//
1558  // Value Processing
1559 
1560  /// Parse an operand reference using the given reader. Returns nullptr in the
1561  /// case of failure.
1562  Value parseOperand(EncodingReader &reader);
1563 
1564  /// Sequentially define the given value range.
1565  LogicalResult defineValues(EncodingReader &reader, ValueRange values);
1566 
1567  /// Create a value to use for a forward reference.
1568  Value createForwardRef();
1569 
1570  //===--------------------------------------------------------------------===//
1571  // Use-list order helpers
1572 
1573  /// This struct is a simple storage that contains information required to
1574  /// reorder the use-list of a value with respect to the pre-order traversal
1575  /// ordering.
1576  struct UseListOrderStorage {
1577  UseListOrderStorage(bool isIndexPairEncoding,
1578  SmallVector<unsigned, 4> &&indices)
1579  : indices(std::move(indices)),
1580  isIndexPairEncoding(isIndexPairEncoding) {};
1581  /// The vector containing the information required to reorder the
1582  /// use-list of a value.
1583  SmallVector<unsigned, 4> indices;
1584 
1585  /// Whether indices represent a pair of type `(src, dst)` or it is a direct
1586  /// indexing, such as `dst = order[src]`.
1587  bool isIndexPairEncoding;
1588  };
1589 
1590  /// Parse use-list order from bytecode for a range of values if available. The
1591  /// range is expected to be either a block argument or an op result range. On
1592  /// success, return a map of the position in the range and the use-list order
1593  /// encoding. The function assumes to know the size of the range it is
1594  /// processing.
1595  using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1596  FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1597  uint64_t rangeSize);
1598 
1599  /// Shuffle the use-chain according to the order parsed.
1600  LogicalResult sortUseListOrder(Value value);
1601 
1602  /// Recursively visit all the values defined within topLevelOp and sort the
1603  /// use-list orders according to the indices parsed.
1604  LogicalResult processUseLists(Operation *topLevelOp);
1605 
1606  //===--------------------------------------------------------------------===//
1607  // Fields
1608 
1609  /// This class represents a single value scope, in which a value scope is
1610  /// delimited by isolated from above regions.
1611  struct ValueScope {
1612  /// Push a new region state onto this scope, reserving enough values for
1613  /// those defined within the current region of the provided state.
1614  void push(RegionReadState &readState) {
1615  nextValueIDs.push_back(values.size());
1616  values.resize(values.size() + readState.numValues);
1617  }
1618 
1619  /// Pop the values defined for the current region within the provided region
1620  /// state.
1621  void pop(RegionReadState &readState) {
1622  values.resize(values.size() - readState.numValues);
1623  nextValueIDs.pop_back();
1624  }
1625 
1626  /// The set of values defined in this scope.
1627  std::vector<Value> values;
1628 
1629  /// The ID for the next defined value for each region current being
1630  /// processed in this scope.
1631  SmallVector<unsigned, 4> nextValueIDs;
1632  };
1633 
1634  /// The configuration of the parser.
1635  const ParserConfig &config;
1636 
1637  /// A location to use when emitting errors.
1638  Location fileLoc;
1639 
1640  /// Flag that indicates if lazyloading is enabled.
1641  bool lazyLoading;
1642 
1643  /// Keep track of operations that have been lazy loaded (their regions haven't
1644  /// been materialized), along with the `RegionReadState` that allows to
1645  /// lazy-load the regions nested under the operation.
1646  LazyLoadableOpsInfo lazyLoadableOps;
1647  LazyLoadableOpsMap lazyLoadableOpsMap;
1648  llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1649 
1650  /// The reader used to process attribute and types within the bytecode.
1651  AttrTypeReader attrTypeReader;
1652 
1653  /// The version of the bytecode being read.
1654  uint64_t version = 0;
1655 
1656  /// The producer of the bytecode being read.
1657  StringRef producer;
1658 
1659  /// The table of IR units referenced within the bytecode file.
1661  llvm::StringMap<BytecodeDialect *> dialectsMap;
1663 
1664  /// The reader used to process resources within the bytecode.
1665  ResourceSectionReader resourceReader;
1666 
1667  /// Worklist of values with custom use-list orders to process before the end
1668  /// of the parsing.
1669  DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1670 
1671  /// The table of strings referenced within the bytecode file.
1672  StringSectionReader stringReader;
1673 
1674  /// The table of properties referenced by the operation in the bytecode file.
1675  PropertiesSectionReader propertiesReader;
1676 
1677  /// The current set of available IR value scopes.
1678  std::vector<ValueScope> valueScopes;
1679 
1680  /// The global pre-order operation ordering.
1681  DenseMap<Operation *, unsigned> operationIDs;
1682 
1683  /// A block containing the set of operations defined to create forward
1684  /// references.
1685  Block forwardRefOps;
1686 
1687  /// A block containing previously created, and no longer used, forward
1688  /// reference operations.
1689  Block openForwardRefOps;
1690 
1691  /// An operation state used when instantiating forward references.
1692  OperationState forwardRefOpState;
1693 
1694  /// Reference to the input buffer.
1695  llvm::MemoryBufferRef buffer;
1696 
1697  /// The optional owning source manager, which when present may be used to
1698  /// extend the lifetime of the input buffer.
1699  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1700 };
1701 
1703  Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1704  EncodingReader reader(buffer.getBuffer(), fileLoc);
1705  this->lazyOpsCallback = lazyOpsCallback;
1706  auto resetlazyOpsCallback =
1707  llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1708 
1709  // Skip over the bytecode header, this should have already been checked.
1710  if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
1711  return failure();
1712  // Parse the bytecode version and producer.
1713  if (failed(parseVersion(reader)) ||
1714  failed(reader.parseNullTerminatedString(producer)))
1715  return failure();
1716 
1717  // Add a diagnostic handler that attaches a note that includes the original
1718  // producer of the bytecode.
1719  ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
1720  diag.attachNote() << "in bytecode version " << version
1721  << " produced by: " << producer;
1722  return failure();
1723  });
1724 
1725  const auto checkSectionAlignment = [&](unsigned alignment) {
1726  return this->checkSectionAlignment(
1727  alignment, [&](const auto &msg) { return reader.emitError(msg); });
1728  };
1729 
1730  // Parse the raw data for each of the top-level sections of the bytecode.
1731  std::optional<ArrayRef<uint8_t>>
1732  sectionDatas[bytecode::Section::kNumSections];
1733  while (!reader.empty()) {
1734  // Read the next section from the bytecode.
1735  bytecode::Section::ID sectionID;
1736  ArrayRef<uint8_t> sectionData;
1737  if (failed(
1738  reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
1739  return failure();
1740 
1741  // Check for duplicate sections, we only expect one instance of each.
1742  if (sectionDatas[sectionID]) {
1743  return reader.emitError("duplicate top-level section: ",
1744  ::toString(sectionID));
1745  }
1746  sectionDatas[sectionID] = sectionData;
1747  }
1748  // Check that all of the required sections were found.
1749  for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
1750  bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
1751  if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
1752  return reader.emitError("missing data for top-level section: ",
1753  ::toString(sectionID));
1754  }
1755  }
1756 
1757  // Process the string section first.
1758  if (failed(stringReader.initialize(
1759  fileLoc, *sectionDatas[bytecode::Section::kString])))
1760  return failure();
1761 
1762  // Process the properties section.
1763  if (sectionDatas[bytecode::Section::kProperties] &&
1764  failed(propertiesReader.initialize(
1765  fileLoc, *sectionDatas[bytecode::Section::kProperties])))
1766  return failure();
1767 
1768  // Process the dialect section.
1769  if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
1770  return failure();
1771 
1772  // Process the resource section if present.
1773  if (failed(parseResourceSection(
1774  reader, sectionDatas[bytecode::Section::kResource],
1775  sectionDatas[bytecode::Section::kResourceOffset])))
1776  return failure();
1777 
1778  // Process the attribute and type section.
1779  if (failed(attrTypeReader.initialize(
1780  dialects, *sectionDatas[bytecode::Section::kAttrType],
1781  *sectionDatas[bytecode::Section::kAttrTypeOffset])))
1782  return failure();
1783 
1784  // Finally, process the IR section.
1785  return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
1786 }
1787 
1788 LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1789  if (failed(reader.parseVarInt(version)))
1790  return failure();
1791 
1792  // Validate the bytecode version.
1793  uint64_t currentVersion = bytecode::kVersion;
1794  uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
1795  if (version < minSupportedVersion) {
1796  return reader.emitError("bytecode version ", version,
1797  " is older than the current version of ",
1798  currentVersion, ", and upgrade is not supported");
1799  }
1800  if (version > currentVersion) {
1801  return reader.emitError("bytecode version ", version,
1802  " is newer than the current version ",
1803  currentVersion);
1804  }
1805  // Override any request to lazy-load if the bytecode version is too old.
1806  if (version < bytecode::kLazyLoading)
1807  lazyLoading = false;
1808  return success();
1809 }
1810 
1811 //===----------------------------------------------------------------------===//
1812 // Dialect Section
1813 //===----------------------------------------------------------------------===//
1814 
1815 LogicalResult BytecodeDialect::load(const DialectReader &reader,
1816  MLIRContext *ctx) {
1817  if (dialect)
1818  return success();
1819  Dialect *loadedDialect = ctx->getOrLoadDialect(name);
1820  if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
1821  return reader.emitError("dialect '")
1822  << name
1823  << "' is unknown. If this is intended, please call "
1824  "allowUnregisteredDialects() on the MLIRContext, or use "
1825  "-allow-unregistered-dialect with the MLIR tool used.";
1826  }
1827  dialect = loadedDialect;
1828 
1829  // If the dialect was actually loaded, check to see if it has a bytecode
1830  // interface.
1831  if (loadedDialect)
1832  interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
1833  if (!versionBuffer.empty()) {
1834  if (!interface)
1835  return reader.emitError("dialect '")
1836  << name
1837  << "' does not implement the bytecode interface, "
1838  "but found a version entry";
1839  EncodingReader encReader(versionBuffer, reader.getLoc());
1840  DialectReader versionReader = reader.withEncodingReader(encReader);
1841  loadedVersion = interface->readVersion(versionReader);
1842  if (!loadedVersion)
1843  return failure();
1844  }
1845  return success();
1846 }
1847 
1848 LogicalResult
1849 BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1850  EncodingReader sectionReader(sectionData, fileLoc);
1851 
1852  // Parse the number of dialects in the section.
1853  uint64_t numDialects;
1854  if (failed(sectionReader.parseVarInt(numDialects)))
1855  return failure();
1856  dialects.resize(numDialects);
1857 
1858  const auto checkSectionAlignment = [&](unsigned alignment) {
1859  return this->checkSectionAlignment(alignment, [&](const auto &msg) {
1860  return sectionReader.emitError(msg);
1861  });
1862  };
1863 
1864  // Parse each of the dialects.
1865  for (uint64_t i = 0; i < numDialects; ++i) {
1866  dialects[i] = std::make_unique<BytecodeDialect>();
1867  /// Before version kDialectVersioning, there wasn't any versioning available
1868  /// for dialects, and the entryIdx represent the string itself.
1869  if (version < bytecode::kDialectVersioning) {
1870  if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
1871  return failure();
1872  continue;
1873  }
1874 
1875  // Parse ID representing dialect and version.
1876  uint64_t dialectNameIdx;
1877  bool versionAvailable;
1878  if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
1879  versionAvailable)))
1880  return failure();
1881  if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
1882  dialects[i]->name)))
1883  return failure();
1884  if (versionAvailable) {
1885  bytecode::Section::ID sectionID;
1886  if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
1887  dialects[i]->versionBuffer)))
1888  return failure();
1889  if (sectionID != bytecode::Section::kDialectVersions) {
1890  emitError(fileLoc, "expected dialect version section");
1891  return failure();
1892  }
1893  }
1894  dialectsMap[dialects[i]->name] = dialects[i].get();
1895  }
1896 
1897  // Parse the operation names, which are grouped by dialect.
1898  auto parseOpName = [&](BytecodeDialect *dialect) {
1899  StringRef opName;
1900  std::optional<bool> wasRegistered;
1901  // Prior to version kNativePropertiesEncoding, the information about wheter
1902  // an op was registered or not wasn't encoded.
1903  if (version < bytecode::kNativePropertiesEncoding) {
1904  if (failed(stringReader.parseString(sectionReader, opName)))
1905  return failure();
1906  } else {
1907  bool wasRegisteredFlag;
1908  if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
1909  wasRegisteredFlag)))
1910  return failure();
1911  wasRegistered = wasRegisteredFlag;
1912  }
1913  opNames.emplace_back(dialect, opName, wasRegistered);
1914  return success();
1915  };
1916  // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
1917  // where the number of ops are known.
1918  if (version >= bytecode::kElideUnknownBlockArgLocation) {
1919  uint64_t numOps;
1920  if (failed(sectionReader.parseVarInt(numOps)))
1921  return failure();
1922  opNames.reserve(numOps);
1923  }
1924  while (!sectionReader.empty())
1925  if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
1926  return failure();
1927  return success();
1928 }
1929 
1930 FailureOr<OperationName>
1931 BytecodeReader::Impl::parseOpName(EncodingReader &reader,
1932  std::optional<bool> &wasRegistered) {
1933  BytecodeOperationName *opName = nullptr;
1934  if (failed(parseEntry(reader, opNames, opName, "operation name")))
1935  return failure();
1936  wasRegistered = opName->wasRegistered;
1937  // Check to see if this operation name has already been resolved. If we
1938  // haven't, load the dialect and build the operation name.
1939  if (!opName->opName) {
1940  // If the opName is empty, this is because we use to accept names such as
1941  // `foo` without any `.` separator. We shouldn't tolerate this in textual
1942  // format anymore but for now we'll be backward compatible. This can only
1943  // happen with unregistered dialects.
1944  if (opName->name.empty()) {
1945  opName->opName.emplace(opName->dialect->name, getContext());
1946  } else {
1947  // Load the dialect and its version.
1948  DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1949  dialectsMap, reader, version);
1950  if (failed(opName->dialect->load(dialectReader, getContext())))
1951  return failure();
1952  opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
1953  getContext());
1954  }
1955  }
1956  return *opName->opName;
1957 }
1958 
1959 //===----------------------------------------------------------------------===//
1960 // Resource Section
1961 //===----------------------------------------------------------------------===//
1962 
1963 LogicalResult BytecodeReader::Impl::parseResourceSection(
1964  EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
1965  std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
1966  // Ensure both sections are either present or not.
1967  if (resourceData.has_value() != resourceOffsetData.has_value()) {
1968  if (resourceOffsetData)
1969  return emitError(fileLoc, "unexpected resource offset section when "
1970  "resource section is not present");
1971  return emitError(
1972  fileLoc,
1973  "expected resource offset section when resource section is present");
1974  }
1975 
1976  // If the resource sections are absent, there is nothing to do.
1977  if (!resourceData)
1978  return success();
1979 
1980  // Initialize the resource reader with the resource sections.
1981  DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1982  dialectsMap, reader, version);
1983  return resourceReader.initialize(fileLoc, config, dialects, stringReader,
1984  *resourceData, *resourceOffsetData,
1985  dialectReader, bufferOwnerRef);
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // UseListOrder Helpers
1990 //===----------------------------------------------------------------------===//
1991 
1992 FailureOr<BytecodeReader::Impl::UseListMapT>
1993 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
1994  uint64_t numResults) {
1996  uint64_t numValuesToRead = 1;
1997  if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead)))
1998  return failure();
1999 
2000  for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
2001  uint64_t resultIdx = 0;
2002  if (numResults > 1 && failed(reader.parseVarInt(resultIdx)))
2003  return failure();
2004 
2005  uint64_t numValues;
2006  bool indexPairEncoding;
2007  if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
2008  return failure();
2009 
2010  SmallVector<unsigned, 4> useListOrders;
2011  for (size_t idx = 0; idx < numValues; idx++) {
2012  uint64_t index;
2013  if (failed(reader.parseVarInt(index)))
2014  return failure();
2015  useListOrders.push_back(index);
2016  }
2017 
2018  // Store in a map the result index
2019  map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
2020  std::move(useListOrders)));
2021  }
2022 
2023  return map;
2024 }
2025 
2026 /// Sorts each use according to the order specified in the use-list parsed. If
2027 /// the custom use-list is not found, this means that the order needs to be
2028 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie
2029 /// on the same operation, the order will follow the reverse operand number
2030 /// ordering.
2031 LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
2032  // Early return for trivial use-lists.
2033  if (value.use_empty() || value.hasOneUse())
2034  return success();
2035 
2036  bool hasIncomingOrder =
2037  valueToUseListMap.contains(value.getAsOpaquePointer());
2038 
2039  // Compute the current order of the use-list with respect to the global
2040  // ordering. Detect if the order is already sorted while doing so.
2041  bool alreadySorted = true;
2042  auto &firstUse = *value.use_begin();
2043  uint64_t prevID =
2044  bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
2045  llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
2046  for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
2047  uint64_t currentID = bytecode::getUseID(
2048  item.value(), operationIDs.at(item.value().getOwner()));
2049  alreadySorted &= prevID > currentID;
2050  currentOrder.push_back({item.index(), currentID});
2051  prevID = currentID;
2052  }
2053 
2054  // If the order is already sorted, and there wasn't a custom order to apply
2055  // from the bytecode file, we are done.
2056  if (alreadySorted && !hasIncomingOrder)
2057  return success();
2058 
2059  // If not already sorted, sort the indices of the current order by descending
2060  // useIDs.
2061  if (!alreadySorted)
2062  std::sort(
2063  currentOrder.begin(), currentOrder.end(),
2064  [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
2065 
2066  if (!hasIncomingOrder) {
2067  // If the bytecode file did not contain any custom use-list order, it means
2068  // that the order was descending useID. Hence, shuffle by the first index
2069  // of the `currentOrder` pair.
2070  SmallVector<unsigned> shuffle(llvm::make_first_range(currentOrder));
2071  value.shuffleUseList(shuffle);
2072  return success();
2073  }
2074 
2075  // Pull the custom order info from the map.
2076  UseListOrderStorage customOrder =
2077  valueToUseListMap.at(value.getAsOpaquePointer());
2078  SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
2079  uint64_t numUses = value.getNumUses();
2080 
2081  // If the encoding was a pair of indices `(src, dst)` for every permutation,
2082  // reconstruct the shuffle vector for every use. Initialize the shuffle vector
2083  // as identity, and then apply the mapping encoded in the indices.
2084  if (customOrder.isIndexPairEncoding) {
2085  // Return failure if the number of indices was not representing pairs.
2086  if (shuffle.size() & 1)
2087  return failure();
2088 
2089  SmallVector<unsigned, 4> newShuffle(numUses);
2090  size_t idx = 0;
2091  std::iota(newShuffle.begin(), newShuffle.end(), idx);
2092  for (idx = 0; idx < shuffle.size(); idx += 2)
2093  newShuffle[shuffle[idx]] = shuffle[idx + 1];
2094 
2095  shuffle = std::move(newShuffle);
2096  }
2097 
2098  // Make sure that the indices represent a valid mapping. That is, the sum of
2099  // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
2100  // duplicates are allowed in the list.
2101  DenseSet<unsigned> set;
2102  uint64_t accumulator = 0;
2103  for (const auto &elem : shuffle) {
2104  if (!set.insert(elem).second)
2105  return failure();
2106  accumulator += elem;
2107  }
2108  if (numUses != shuffle.size() ||
2109  accumulator != (((numUses - 1) * numUses) >> 1))
2110  return failure();
2111 
2112  // Apply the current ordering map onto the shuffle vector to get the final
2113  // use-list sorting indices before shuffling.
2114  shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2115  currentOrder, [&](auto item) { return shuffle[item.first]; }));
2116  value.shuffleUseList(shuffle);
2117  return success();
2118 }
2119 
2120 LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2121  // Precompute operation IDs according to the pre-order walk of the IR. We
2122  // can't do this while parsing since parseRegions ordering is not strictly
2123  // equal to the pre-order walk.
2124  unsigned operationID = 0;
2125  topLevelOp->walk<mlir::WalkOrder::PreOrder>(
2126  [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
2127 
2128  auto blockWalk = topLevelOp->walk([this](Block *block) {
2129  for (auto arg : block->getArguments())
2130  if (failed(sortUseListOrder(arg)))
2131  return WalkResult::interrupt();
2132  return WalkResult::advance();
2133  });
2134 
2135  auto resultWalk = topLevelOp->walk([this](Operation *op) {
2136  for (auto result : op->getResults())
2137  if (failed(sortUseListOrder(result)))
2138  return WalkResult::interrupt();
2139  return WalkResult::advance();
2140  });
2141 
2142  return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2143 }
2144 
2145 //===----------------------------------------------------------------------===//
2146 // IR Section
2147 //===----------------------------------------------------------------------===//
2148 
2149 LogicalResult
2150 BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2151  Block *block) {
2152  EncodingReader reader(sectionData, fileLoc);
2153 
2154  // A stack of operation regions currently being read from the bytecode.
2155  std::vector<RegionReadState> regionStack;
2156 
2157  // Parse the top-level block using a temporary module operation.
2158  OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2159  regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
2160  regionStack.back().curBlocks.push_back(moduleOp->getBody());
2161  regionStack.back().curBlock = regionStack.back().curRegion->begin();
2162  if (failed(parseBlockHeader(reader, regionStack.back())))
2163  return failure();
2164  valueScopes.emplace_back();
2165  valueScopes.back().push(regionStack.back());
2166 
2167  // Iteratively parse regions until everything has been resolved.
2168  while (!regionStack.empty())
2169  if (failed(parseRegions(regionStack, regionStack.back())))
2170  return failure();
2171  if (!forwardRefOps.empty()) {
2172  return reader.emitError(
2173  "not all forward unresolved forward operand references");
2174  }
2175 
2176  // Sort use-lists according to what specified in bytecode.
2177  if (failed(processUseLists(*moduleOp)))
2178  return reader.emitError(
2179  "parsed use-list orders were invalid and could not be applied");
2180 
2181  // Resolve dialect version.
2182  for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2183  // Parsing is complete, give an opportunity to each dialect to visit the
2184  // IR and perform upgrades.
2185  if (!byteCodeDialect->loadedVersion)
2186  continue;
2187  if (byteCodeDialect->interface &&
2188  failed(byteCodeDialect->interface->upgradeFromVersion(
2189  *moduleOp, *byteCodeDialect->loadedVersion)))
2190  return failure();
2191  }
2192 
2193  // Verify that the parsed operations are valid.
2194  if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
2195  return failure();
2196 
2197  // Splice the parsed operations over to the provided top-level block.
2198  auto &parsedOps = moduleOp->getBody()->getOperations();
2199  auto &destOps = block->getOperations();
2200  destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
2201  return success();
2202 }
2203 
2204 LogicalResult
2205 BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
2206  RegionReadState &readState) {
2207  const auto checkSectionAlignment = [&](unsigned alignment) {
2208  return this->checkSectionAlignment(
2209  alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
2210  };
2211 
2212  // Process regions, blocks, and operations until the end or if a nested
2213  // region is encountered. In this case we push a new state in regionStack and
2214  // return, the processing of the current region will resume afterward.
2215  for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
2216  // If the current block hasn't been setup yet, parse the header for this
2217  // region. The current block is already setup when this function was
2218  // interrupted to recurse down in a nested region and we resume the current
2219  // block after processing the nested region.
2220  if (readState.curBlock == Region::iterator()) {
2221  if (failed(parseRegion(readState)))
2222  return failure();
2223 
2224  // If the region is empty, there is nothing to more to do.
2225  if (readState.curRegion->empty())
2226  continue;
2227  }
2228 
2229  // Parse the blocks within the region.
2230  EncodingReader &reader = *readState.reader;
2231  do {
2232  while (readState.numOpsRemaining--) {
2233  // Read in the next operation. We don't read its regions directly, we
2234  // handle those afterwards as necessary.
2235  bool isIsolatedFromAbove = false;
2236  FailureOr<Operation *> op =
2237  parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
2238  if (failed(op))
2239  return failure();
2240 
2241  // If the op has regions, add it to the stack for processing and return:
2242  // we stop the processing of the current region and resume it after the
2243  // inner one is completed. Unless LazyLoading is activated in which case
2244  // nested region parsing is delayed.
2245  if ((*op)->getNumRegions()) {
2246  RegionReadState childState(*op, &reader, isIsolatedFromAbove);
2247 
2248  // Isolated regions are encoded as a section in version 2 and above.
2249  if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
2250  bytecode::Section::ID sectionID;
2251  ArrayRef<uint8_t> sectionData;
2252  if (failed(reader.parseSection(sectionID, checkSectionAlignment,
2253  sectionData)))
2254  return failure();
2255  if (sectionID != bytecode::Section::kIR)
2256  return emitError(fileLoc, "expected IR section for region");
2257  childState.owningReader =
2258  std::make_unique<EncodingReader>(sectionData, fileLoc);
2259  childState.reader = childState.owningReader.get();
2260 
2261  // If the user has a callback set, they have the opportunity to
2262  // control lazyloading as we go.
2263  if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
2264  lazyLoadableOps.emplace_back(*op, std::move(childState));
2265  lazyLoadableOpsMap.try_emplace(*op,
2266  std::prev(lazyLoadableOps.end()));
2267  continue;
2268  }
2269  }
2270  regionStack.push_back(std::move(childState));
2271 
2272  // If the op is isolated from above, push a new value scope.
2273  if (isIsolatedFromAbove)
2274  valueScopes.emplace_back();
2275  return success();
2276  }
2277  }
2278 
2279  // Move to the next block of the region.
2280  if (++readState.curBlock == readState.curRegion->end())
2281  break;
2282  if (failed(parseBlockHeader(reader, readState)))
2283  return failure();
2284  } while (true);
2285 
2286  // Reset the current block and any values reserved for this region.
2287  readState.curBlock = {};
2288  valueScopes.back().pop(readState);
2289  }
2290 
2291  // When the regions have been fully parsed, pop them off of the read stack. If
2292  // the regions were isolated from above, we also pop the last value scope.
2293  if (readState.isIsolatedFromAbove) {
2294  assert(!valueScopes.empty() && "Expect a valueScope after reading region");
2295  valueScopes.pop_back();
2296  }
2297  assert(!regionStack.empty() && "Expect a regionStack after reading region");
2298  regionStack.pop_back();
2299  return success();
2300 }
2301 
2302 FailureOr<Operation *>
2303 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2304  RegionReadState &readState,
2305  bool &isIsolatedFromAbove) {
2306  // Parse the name of the operation.
2307  std::optional<bool> wasRegistered;
2308  FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2309  if (failed(opName))
2310  return failure();
2311 
2312  // Parse the operation mask, which indicates which components of the operation
2313  // are present.
2314  uint8_t opMask;
2315  if (failed(reader.parseByte(opMask)))
2316  return failure();
2317 
2318  /// Parse the location.
2319  LocationAttr opLoc;
2320  if (failed(parseAttribute(reader, opLoc)))
2321  return failure();
2322 
2323  // With the location and name resolved, we can start building the operation
2324  // state.
2325  OperationState opState(opLoc, *opName);
2326 
2327  // Parse the attributes of the operation.
2328  if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
2329  DictionaryAttr dictAttr;
2330  if (failed(parseAttribute(reader, dictAttr)))
2331  return failure();
2332  opState.attributes = dictAttr;
2333  }
2334 
2336  // kHasProperties wasn't emitted in older bytecode, we should never get
2337  // there without also having the `wasRegistered` flag available.
2338  if (!wasRegistered)
2339  return emitError(fileLoc,
2340  "Unexpected missing `wasRegistered` opname flag at "
2341  "bytecode version ")
2342  << version << " with properties.";
2343  // When an operation is emitted without being registered, the properties are
2344  // stored as an attribute. Otherwise the op must implement the bytecode
2345  // interface and control the serialization.
2346  if (wasRegistered) {
2347  DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2348  dialectsMap, reader, version);
2349  if (failed(
2350  propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
2351  return failure();
2352  } else {
2353  // If the operation wasn't registered when it was emitted, the properties
2354  // was serialized as an attribute.
2355  if (failed(parseAttribute(reader, opState.propertiesAttr)))
2356  return failure();
2357  }
2358  }
2359 
2360  /// Parse the results of the operation.
2362  uint64_t numResults;
2363  if (failed(reader.parseVarInt(numResults)))
2364  return failure();
2365  opState.types.resize(numResults);
2366  for (int i = 0, e = numResults; i < e; ++i)
2367  if (failed(parseType(reader, opState.types[i])))
2368  return failure();
2369  }
2370 
2371  /// Parse the operands of the operation.
2373  uint64_t numOperands;
2374  if (failed(reader.parseVarInt(numOperands)))
2375  return failure();
2376  opState.operands.resize(numOperands);
2377  for (int i = 0, e = numOperands; i < e; ++i)
2378  if (!(opState.operands[i] = parseOperand(reader)))
2379  return failure();
2380  }
2381 
2382  /// Parse the successors of the operation.
2384  uint64_t numSuccs;
2385  if (failed(reader.parseVarInt(numSuccs)))
2386  return failure();
2387  opState.successors.resize(numSuccs);
2388  for (int i = 0, e = numSuccs; i < e; ++i) {
2389  if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
2390  "successor")))
2391  return failure();
2392  }
2393  }
2394 
2395  /// Parse the use-list orders for the results of the operation. Use-list
2396  /// orders are available since version 3 of the bytecode.
2397  std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2398  if (version >= bytecode::kUseListOrdering &&
2400  size_t numResults = opState.types.size();
2401  auto parseResult = parseUseListOrderForRange(reader, numResults);
2402  if (failed(parseResult))
2403  return failure();
2404  resultIdxToUseListMap = std::move(*parseResult);
2405  }
2406 
2407  /// Parse the regions of the operation.
2409  uint64_t numRegions;
2410  if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
2411  return failure();
2412 
2413  opState.regions.reserve(numRegions);
2414  for (int i = 0, e = numRegions; i < e; ++i)
2415  opState.regions.push_back(std::make_unique<Region>());
2416  }
2417 
2418  // Create the operation at the back of the current block.
2419  Operation *op = Operation::create(opState);
2420  readState.curBlock->push_back(op);
2421 
2422  // If the operation had results, update the value references. We don't need to
2423  // do this if the current value scope is empty. That is, the op was not
2424  // encoded within a parent region.
2425  if (readState.numValues && op->getNumResults() &&
2426  failed(defineValues(reader, op->getResults())))
2427  return failure();
2428 
2429  /// Store a map for every value that received a custom use-list order from the
2430  /// bytecode file.
2431  if (resultIdxToUseListMap.has_value()) {
2432  for (size_t idx = 0; idx < op->getNumResults(); idx++) {
2433  if (resultIdxToUseListMap->contains(idx)) {
2434  valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
2435  resultIdxToUseListMap->at(idx));
2436  }
2437  }
2438  }
2439  return op;
2440 }
2441 
2442 LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2443  EncodingReader &reader = *readState.reader;
2444 
2445  // Parse the number of blocks in the region.
2446  uint64_t numBlocks;
2447  if (failed(reader.parseVarInt(numBlocks)))
2448  return failure();
2449 
2450  // If the region is empty, there is nothing else to do.
2451  if (numBlocks == 0)
2452  return success();
2453 
2454  // Parse the number of values defined in this region.
2455  uint64_t numValues;
2456  if (failed(reader.parseVarInt(numValues)))
2457  return failure();
2458  readState.numValues = numValues;
2459 
2460  // Create the blocks within this region. We do this before processing so that
2461  // we can rely on the blocks existing when creating operations.
2462  readState.curBlocks.clear();
2463  readState.curBlocks.reserve(numBlocks);
2464  for (uint64_t i = 0; i < numBlocks; ++i) {
2465  readState.curBlocks.push_back(new Block());
2466  readState.curRegion->push_back(readState.curBlocks.back());
2467  }
2468 
2469  // Prepare the current value scope for this region.
2470  valueScopes.back().push(readState);
2471 
2472  // Parse the entry block of the region.
2473  readState.curBlock = readState.curRegion->begin();
2474  return parseBlockHeader(reader, readState);
2475 }
2476 
2477 LogicalResult
2478 BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2479  RegionReadState &readState) {
2480  bool hasArgs;
2481  if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2482  return failure();
2483 
2484  // Parse the arguments of the block.
2485  if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
2486  return failure();
2487 
2488  // Uselist orders are available since version 3 of the bytecode.
2489  if (version < bytecode::kUseListOrdering)
2490  return success();
2491 
2492  uint8_t hasUseListOrders = 0;
2493  if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
2494  return failure();
2495 
2496  if (!hasUseListOrders)
2497  return success();
2498 
2499  Block &blk = *readState.curBlock;
2500  auto argIdxToUseListMap =
2501  parseUseListOrderForRange(reader, blk.getNumArguments());
2502  if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
2503  return failure();
2504 
2505  for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
2506  if (argIdxToUseListMap->contains(idx))
2507  valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
2508  argIdxToUseListMap->at(idx));
2509 
2510  // We don't parse the operations of the block here, that's done elsewhere.
2511  return success();
2512 }
2513 
2514 LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2515  Block *block) {
2516  // Parse the value ID for the first argument, and the number of arguments.
2517  uint64_t numArgs;
2518  if (failed(reader.parseVarInt(numArgs)))
2519  return failure();
2520 
2521  SmallVector<Type> argTypes;
2522  SmallVector<Location> argLocs;
2523  argTypes.reserve(numArgs);
2524  argLocs.reserve(numArgs);
2525 
2526  Location unknownLoc = UnknownLoc::get(config.getContext());
2527  while (numArgs--) {
2528  Type argType;
2529  LocationAttr argLoc = unknownLoc;
2530  if (version >= bytecode::kElideUnknownBlockArgLocation) {
2531  // Parse the type with hasLoc flag to determine if it has type.
2532  uint64_t typeIdx;
2533  bool hasLoc;
2534  if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2535  !(argType = attrTypeReader.resolveType(typeIdx)))
2536  return failure();
2537  if (hasLoc && failed(parseAttribute(reader, argLoc)))
2538  return failure();
2539  } else {
2540  // All args has type and location.
2541  if (failed(parseType(reader, argType)) ||
2542  failed(parseAttribute(reader, argLoc)))
2543  return failure();
2544  }
2545  argTypes.push_back(argType);
2546  argLocs.push_back(argLoc);
2547  }
2548  block->addArguments(argTypes, argLocs);
2549  return defineValues(reader, block->getArguments());
2550 }
2551 
2552 //===----------------------------------------------------------------------===//
2553 // Value Processing
2554 //===----------------------------------------------------------------------===//
2555 
2556 Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2557  std::vector<Value> &values = valueScopes.back().values;
2558  Value *value = nullptr;
2559  if (failed(parseEntry(reader, values, value, "value")))
2560  return Value();
2561 
2562  // Create a new forward reference if necessary.
2563  if (!*value)
2564  *value = createForwardRef();
2565  return *value;
2566 }
2567 
2568 LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2569  ValueRange newValues) {
2570  ValueScope &valueScope = valueScopes.back();
2571  std::vector<Value> &values = valueScope.values;
2572 
2573  unsigned &valueID = valueScope.nextValueIDs.back();
2574  unsigned valueIDEnd = valueID + newValues.size();
2575  if (valueIDEnd > values.size()) {
2576  return reader.emitError(
2577  "value index range was outside of the expected range for "
2578  "the parent region, got [",
2579  valueID, ", ", valueIDEnd, "), but the maximum index was ",
2580  values.size() - 1);
2581  }
2582 
2583  // Assign the values and update any forward references.
2584  for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2585  Value newValue = newValues[i];
2586 
2587  // Check to see if a definition for this value already exists.
2588  if (Value oldValue = std::exchange(values[valueID], newValue)) {
2589  Operation *forwardRefOp = oldValue.getDefiningOp();
2590 
2591  // Assert that this is a forward reference operation. Given how we compute
2592  // definition ids (incrementally as we parse), it shouldn't be possible
2593  // for the value to be defined any other way.
2594  assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
2595  "value index was already defined?");
2596 
2597  oldValue.replaceAllUsesWith(newValue);
2598  forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
2599  }
2600  }
2601  return success();
2602 }
2603 
2604 Value BytecodeReader::Impl::createForwardRef() {
2605  // Check for an available existing operation to use. Otherwise, create a new
2606  // fake operation to use for the reference.
2607  if (!openForwardRefOps.empty()) {
2608  Operation *op = &openForwardRefOps.back();
2609  op->moveBefore(&forwardRefOps, forwardRefOps.end());
2610  } else {
2611  forwardRefOps.push_back(Operation::create(forwardRefOpState));
2612  }
2613  return forwardRefOps.back().getResult(0);
2614 }
2615 
2616 //===----------------------------------------------------------------------===//
2617 // Entry Points
2618 //===----------------------------------------------------------------------===//
2619 
2621 
2623  llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
2624  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2625  Location sourceFileLoc =
2626  FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2627  /*line=*/0, /*column=*/0);
2628  impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
2629  bufferOwnerRef);
2630 }
2631 
2633  Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2634  return impl->read(block, lazyOpsCallback);
2635 }
2636 
2638  return impl->getNumOpsToMaterialize();
2639 }
2640 
2642  return impl->isMaterializable(op);
2643 }
2644 
2646  Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2647  return impl->materialize(op, lazyOpsCallback);
2648 }
2649 
2650 LogicalResult
2651 BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
2652  return impl->finalize(shouldMaterialize);
2653 }
2654 
2655 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
2656  return buffer.getBuffer().starts_with("ML\xefR");
2657 }
2658 
2659 /// Read the bytecode from the provided memory buffer reference.
2660 /// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2661 /// and may be used to extend the lifetime of the buffer.
2662 static LogicalResult
2663 readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
2664  const ParserConfig &config,
2665  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2666  Location sourceFileLoc =
2667  FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2668  /*line=*/0, /*column=*/0);
2669  if (!isBytecode(buffer)) {
2670  return emitError(sourceFileLoc,
2671  "input buffer is not an MLIR bytecode file");
2672  }
2673 
2674  BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
2675  buffer, bufferOwnerRef);
2676  return reader.read(block, /*lazyOpsCallback=*/nullptr);
2677 }
2678 
2679 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
2680  const ParserConfig &config) {
2681  return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
2682 }
2683 LogicalResult
2684 mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
2685  Block *block, const ParserConfig &config) {
2686  return readBytecodeFileImpl(
2687  *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
2688  sourceMgr);
2689 }
static LogicalResult readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
Read the bytecode from the provided memory buffer reference.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version)
Returns true if the given top-level section ID is optional.
static LogicalResult parseResourceGroup(Location fileLoc, bool allowEmpty, EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef, function_ref< StringRef(StringRef)> remapKey={}, function_ref< LogicalResult(StringRef)> processKeyFn={})
static LogicalResult parseDialectGrouping(EncodingReader &reader, MutableArrayRef< std::unique_ptr< BytecodeDialect >> dialects, function_ref< LogicalResult(BytecodeDialect *)> entryCallback)
Parse a single dialect group encoded in the byte stream.
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, uint64_t index, T &entry, StringRef entryStr)
Resolve an index into the given entry list.
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, T &entry, StringRef entryStr)
Parse and resolve an index into the given entry list.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:811
This class represents an opaque handle to a dialect resource entry.
This class represents a single parsed resource entry.
Definition: AsmState.h:291
This class represents a processed binary blob of data.
Definition: AsmState.h:91
MutableArrayRef< char > getMutableData()
Return a mutable reference to the raw underlying data of this blob.
Definition: AsmState.h:157
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:145
bool isMutable() const
Return if the data of this blob is mutable.
Definition: AsmState.h:164
This class represents an instance of a resource parser.
Definition: AsmState.h:339
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
This class is used to read a bytecode buffer and translate it into MLIR.
LogicalResult materializeAll()
Materialize all operations.
LogicalResult read(Block *block, llvm::function_ref< bool(Operation *)> lazyOps)
Read the bytecode defined within buffer into the given block.
bool isMaterializable(Operation *op)
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, llvm::MemoryBufferRef buffer, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize)
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback)
Materialize the provided operation, invoke the lazyOpsCallback on every newly found lazy operation.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback=[](Operation *) { return false;})
Materialize the provided operation.
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize=[](Operation *) { return true;})
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
BytecodeReader(llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef={})
Create a bytecode reader for the given buffer.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
bool isMaterializable(Operation *op)
Return true if the provided op is materializable.
LogicalResult readTopLevel(Block *block, llvm::function_ref< bool(Operation *)> lazyOps=[](Operation *) { return false;})
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:157
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
Definition: Diagnostics.h:340
Location objects represent source locations information in MLIR.
Definition: Location.h:32
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition: Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
bool allowsUnregisteredDialects()
Return true if creating operations for unregistered dialects is allowed.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:100
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
bool isRegistered() const
Return if this operation is registered.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Definition: Operation.cpp:585
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:554
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents a configuration for the MLIR assembly parser.
Definition: AsmState.h:469
BlockListType::iterator iterator
Definition: Region.h:52
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Definition: Diagnostics.h:522
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static AsmResourceBlob allocateWithAlign(ArrayRef< char > data, size_t align, AsmResourceBlob::DeleterFn deleter={}, bool dataIsMutable=false)
Create a new unmanaged resource directly referencing the provided data.
Definition: AsmState.h:228
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
void shuffleUseList(ArrayRef< unsigned > indices)
Shuffle the use list order according to the provided indices.
Definition: Value.cpp:106
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:233
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition: Value.cpp:52
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
use_iterator use_begin() const
Definition: Value.h:184
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
@ kAttrType
This section contains the attributes and types referenced within an IR module.
Definition: Encoding.h:73
@ kAttrTypeOffset
This section contains the offsets for the attribute and types within the AttrType section.
Definition: Encoding.h:77
@ kIR
This section contains the list of operations serialized into the bytecode, and their nested regions/o...
Definition: Encoding.h:81
@ kResource
This section contains the resources of the bytecode.
Definition: Encoding.h:84
@ kResourceOffset
This section contains the offsets of resources within the Resource section.
Definition: Encoding.h:88
@ kDialect
This section contains the dialects referenced within an IR module.
Definition: Encoding.h:69
@ kString
This section contains strings referenced within the bytecode.
Definition: Encoding.h:66
@ kDialectVersions
This section contains the versions of each dialect.
Definition: Encoding.h:91
@ kProperties
This section contains the properties for the operations.
Definition: Encoding.h:94
@ kNumSections
The total number of section types.
Definition: Encoding.h:97
static uint64_t getUseID(OperandT &val, unsigned ownerID)
Get the unique ID of a value use.
Definition: Encoding.h:127
@ kUseListOrdering
Use-list ordering started to be encoded in version 3.
Definition: Encoding.h:38
@ kAlignmentByte
An arbitrary value used to fill alignment padding.
Definition: Encoding.h:56
@ kVersion
The current bytecode version.
Definition: Encoding.h:53
@ kLazyLoading
Support for lazy-loading of isolated region was added in version 2.
Definition: Encoding.h:35
@ kDialectVersioning
Dialect versioning was added in version 1.
Definition: Encoding.h:32
@ kElideUnknownBlockArgLocation
Eliding unknown locations on block arguments (for compression) started in version 4.
Definition: Encoding.h:42
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attribute...
Definition: Encoding.h:46
@ kMinSupportedVersion
The minimum supported version of the bytecode.
Definition: Encoding.h:29
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
StringRef toString(AsmResourceEntryKind kind)
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
bool isBytecode(llvm::MemoryBufferRef buffer)
Returns true if the given buffer starts with the magic bytes that signal MLIR bytecode.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
AsmResourceEntryKind
This enum represents the different kinds of resource values.
Definition: AsmState.h:280
@ String
A string value.
@ Bool
A boolean value.
@ Blob
A blob of data with an accompanying alignment.
LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config)
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.