MLIR 22.0.0git
BytecodeReader.cpp
Go to the documentation of this file.
1//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/Diagnostics.h"
17#include "mlir/IR/Verifier.h"
18#include "mlir/IR/Visitors.h"
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/Endian.h"
25#include "llvm/Support/MemoryBufferRef.h"
26#include "llvm/Support/SourceMgr.h"
27
28#include <cstddef>
29#include <cstdint>
30#include <deque>
31#include <list>
32#include <memory>
33#include <numeric>
34#include <optional>
35
36#define DEBUG_TYPE "mlir-bytecode-reader"
37
38using namespace mlir;
39
40/// Stringify the given section ID.
41static std::string toString(bytecode::Section::ID sectionID) {
42 switch (sectionID) {
44 return "String (0)";
46 return "Dialect (1)";
48 return "AttrType (2)";
50 return "AttrTypeOffset (3)";
52 return "IR (4)";
54 return "Resource (5)";
56 return "ResourceOffset (6)";
58 return "DialectVersions (7)";
60 return "Properties (8)";
61 default:
62 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
63 }
64}
65
66/// Returns true if the given top-level section ID is optional.
67static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
68 switch (sectionID) {
74 return false;
78 return true;
81 default:
82 llvm_unreachable("unknown section ID");
83 }
84}
85
86//===----------------------------------------------------------------------===//
87// EncodingReader
88//===----------------------------------------------------------------------===//
89
90namespace {
91class EncodingReader {
92public:
93 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
94 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
95 explicit EncodingReader(StringRef contents, Location fileLoc)
96 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
97 contents.size()},
98 fileLoc) {}
99
100 /// Returns true if the entire section has been read.
101 bool empty() const { return dataIt == buffer.end(); }
102
103 /// Returns the remaining size of the bytecode.
104 size_t size() const { return buffer.end() - dataIt; }
105
106 /// Align the current reader position to the specified alignment.
107 LogicalResult alignTo(unsigned alignment) {
108 if (!llvm::isPowerOf2_32(alignment))
109 return emitError("expected alignment to be a power-of-two");
110
111 auto isUnaligned = [&](const uint8_t *ptr) {
112 return ((uintptr_t)ptr & (alignment - 1)) != 0;
113 };
114
115 // Shift the reader position to the next alignment boundary.
116 // Note: this assumes the pointer alignment matches the alignment of the
117 // data from the start of the buffer. In other words, this code is only
118 // valid if `dataIt` is offsetting into an already aligned buffer.
119 while (isUnaligned(dataIt)) {
120 uint8_t padding;
121 if (failed(parseByte(padding)))
122 return failure();
123 if (padding != bytecode::kAlignmentByte) {
124 return emitError("expected alignment byte (0xCB), but got: '0x" +
125 llvm::utohexstr(padding) + "'");
126 }
127 }
128
129 // Ensure the data iterator is now aligned. This case is unlikely because we
130 // *just* went through the effort to align the data iterator.
131 if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
132 return emitError("expected data iterator aligned to ", alignment,
133 ", but got pointer: '0x" +
134 llvm::utohexstr((uintptr_t)dataIt) + "'");
135 }
136
137 return success();
138 }
139
140 /// Emit an error using the given arguments.
141 template <typename... Args>
142 InFlightDiagnostic emitError(Args &&...args) const {
143 return ::emitError(fileLoc).append(std::forward<Args>(args)...);
144 }
145 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
146
147 /// Parse a single byte from the stream.
148 template <typename T>
149 LogicalResult parseByte(T &value) {
150 if (empty())
151 return emitError("attempting to parse a byte at the end of the bytecode");
152 value = static_cast<T>(*dataIt++);
153 return success();
154 }
155 /// Parse a range of bytes of 'length' into the given result.
156 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
157 if (length > size()) {
158 return emitError("attempting to parse ", length, " bytes when only ",
159 size(), " remain");
160 }
161 result = {dataIt, length};
162 dataIt += length;
163 return success();
164 }
165 /// Parse a range of bytes of 'length' into the given result, which can be
166 /// assumed to be large enough to hold `length`.
167 LogicalResult parseBytes(size_t length, uint8_t *result) {
168 if (length > size()) {
169 return emitError("attempting to parse ", length, " bytes when only ",
170 size(), " remain");
171 }
172 memcpy(result, dataIt, length);
173 dataIt += length;
174 return success();
175 }
176
177 /// Parse an aligned blob of data, where the alignment was encoded alongside
178 /// the data.
179 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
180 uint64_t &alignment) {
181 uint64_t dataSize;
182 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
183 failed(alignTo(alignment)))
184 return failure();
185 return parseBytes(dataSize, data);
186 }
187
188 /// Parse a variable length encoded integer from the byte stream. The first
189 /// encoded byte contains a prefix in the low bits indicating the encoded
190 /// length of the value. This length prefix is a bit sequence of '0's followed
191 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
192 /// (not including the prefix byte). All remaining bits in the first byte,
193 /// along with all of the bits in additional bytes, provide the value of the
194 /// integer encoded in little-endian order.
195 LogicalResult parseVarInt(uint64_t &result) {
196 // Parse the first byte of the encoding, which contains the length prefix.
197 if (failed(parseByte(result)))
198 return failure();
199
200 // Handle the overwhelmingly common case where the value is stored in a
201 // single byte. In this case, the first bit is the `1` marker bit.
202 if (LLVM_LIKELY(result & 1)) {
203 result >>= 1;
204 return success();
205 }
206
207 // Handle the overwhelming uncommon case where the value required all 8
208 // bytes (i.e. a really really big number). In this case, the marker byte is
209 // all zeros: `00000000`.
210 if (LLVM_UNLIKELY(result == 0)) {
211 llvm::support::ulittle64_t resultLE;
212 if (failed(parseBytes(sizeof(resultLE),
213 reinterpret_cast<uint8_t *>(&resultLE))))
214 return failure();
215 result = resultLE;
216 return success();
217 }
218 return parseMultiByteVarInt(result);
219 }
220
221 /// Parse a signed variable length encoded integer from the byte stream. A
222 /// signed varint is encoded as a normal varint with zigzag encoding applied,
223 /// i.e. the low bit of the value is used to indicate the sign.
224 LogicalResult parseSignedVarInt(uint64_t &result) {
225 if (failed(parseVarInt(result)))
226 return failure();
227 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
228 result = (result >> 1) ^ (~(result & 1) + 1);
229 return success();
230 }
231
232 /// Parse a variable length encoded integer whose low bit is used to encode an
233 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
234 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
235 if (failed(parseVarInt(result)))
236 return failure();
237 flag = result & 1;
238 result >>= 1;
239 return success();
240 }
241
242 /// Skip the first `length` bytes within the reader.
243 LogicalResult skipBytes(size_t length) {
244 if (length > size()) {
245 return emitError("attempting to skip ", length, " bytes when only ",
246 size(), " remain");
247 }
248 dataIt += length;
249 return success();
250 }
251
252 /// Parse a null-terminated string into `result` (without including the NUL
253 /// terminator).
254 LogicalResult parseNullTerminatedString(StringRef &result) {
255 const char *startIt = (const char *)dataIt;
256 const char *nulIt = (const char *)memchr(startIt, 0, size());
257 if (!nulIt)
258 return emitError(
259 "malformed null-terminated string, no null character found");
260
261 result = StringRef(startIt, nulIt - startIt);
262 dataIt = (const uint8_t *)nulIt + 1;
263 return success();
264 }
265
266 /// Validate that the alignment requested in the section is valid.
267 using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
268
269 /// Parse a section header, placing the kind of section in `sectionID` and the
270 /// contents of the section in `sectionData`.
271 LogicalResult parseSection(bytecode::Section::ID &sectionID,
272 ValidateAlignmentFn alignmentValidator,
273 ArrayRef<uint8_t> &sectionData) {
274 uint8_t sectionIDAndHasAlignment;
275 uint64_t length;
276 if (failed(parseByte(sectionIDAndHasAlignment)) ||
277 failed(parseVarInt(length)))
278 return failure();
279
280 // Extract the section ID and whether the section is aligned. The high bit
281 // of the ID is the alignment flag.
282 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
283 0b01111111);
284 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
285
286 // Check that the section is actually valid before trying to process its
287 // data.
288 if (sectionID >= bytecode::Section::kNumSections)
289 return emitError("invalid section ID: ", unsigned(sectionID));
290
291 // Process the section alignment if present.
292 if (hasAlignment) {
293 // Read the requested alignment from the bytecode parser.
294 uint64_t alignment;
295 if (failed(parseVarInt(alignment)))
296 return failure();
297
298 // Check that the requested alignment must not exceed the alignment of
299 // the root buffer itself. Otherwise we cannot guarantee that pointers
300 // derived from this buffer will actually satisfy the requested alignment
301 // globally.
302 //
303 // Consider a bytecode buffer that is guaranteed to be 8k aligned, but not
304 // 16k aligned (e.g. absolute address 40960. If a section inside this
305 // buffer declares a 16k alignment requirement, two problems can arise:
306 //
307 // (a) If we "align forward" the current pointer to the next
308 // 16k boundary, the amount of padding we skip depends on the
309 // buffer's starting address. For example:
310 //
311 // buffer_start = 40960
312 // next 16k boundary = 49152
313 // bytes skipped = 49152 - 40960 = 8192
314 //
315 // This leaves behind variable padding that could be misinterpreted
316 // as part of the next section.
317 //
318 // (b) If we align relative to the buffer start, we may
319 // obtain addresses that are multiples of "buffer_start +
320 // section_alignment" rather than truly globally aligned
321 // addresses. For example:
322 //
323 // buffer_start = 40960 (5×8k, 8k aligned but not 16k)
324 // offset = 16384 (first multiple of 16k)
325 // section_ptr = 40960 + 16384 = 57344
326 //
327 // 57344 is 8k aligned but not 16k aligned.
328 // Any consumer expecting true 16k alignment would see this as a
329 // violation.
330 if (failed(alignmentValidator(alignment)))
331 return emitError("failed to align section ID: ", unsigned(sectionID));
332
333 // Align the buffer.
334 if (failed(alignTo(alignment)))
335 return failure();
336 }
337
338 // Parse the actual section data.
339 return parseBytes(static_cast<size_t>(length), sectionData);
340 }
341
342 Location getLoc() const { return fileLoc; }
343
344private:
345 /// Parse a variable length encoded integer from the byte stream. This method
346 /// is a fallback when the number of bytes used to encode the value is greater
347 /// than 1, but less than the max (9). The provided `result` value can be
348 /// assumed to already contain the first byte of the value.
349 /// NOTE: This method is marked noinline to avoid pessimizing the common case
350 /// of single byte encoding.
351 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
352 // Count the number of trailing zeros in the marker byte, this indicates the
353 // number of trailing bytes that are part of the value. We use `uint32_t`
354 // here because we only care about the first byte, and so that be actually
355 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
356 // implementation).
357 uint32_t numBytes = llvm::countr_zero<uint32_t>(result);
358 assert(numBytes > 0 && numBytes <= 7 &&
359 "unexpected number of trailing zeros in varint encoding");
360
361 // Parse in the remaining bytes of the value.
362 llvm::support::ulittle64_t resultLE(result);
363 if (failed(
364 parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1)))
365 return failure();
366
367 // Shift out the low-order bits that were used to mark how the value was
368 // encoded.
369 result = resultLE >> (numBytes + 1);
370 return success();
371 }
372
373 /// The bytecode buffer.
374 ArrayRef<uint8_t> buffer;
375
376 /// The current iterator within the 'buffer'.
377 const uint8_t *dataIt;
378
379 /// A location for the bytecode used to report errors.
380 Location fileLoc;
381};
382} // namespace
383
384/// Resolve an index into the given entry list. `entry` may either be a
385/// reference, in which case it is assigned to the corresponding value in
386/// `entries`, or a pointer, in which case it is assigned to the address of the
387/// element in `entries`.
388template <typename RangeT, typename T>
389static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
390 uint64_t index, T &entry,
391 StringRef entryStr) {
392 if (index >= entries.size())
393 return reader.emitError("invalid ", entryStr, " index: ", index);
394
395 // If the provided entry is a pointer, resolve to the address of the entry.
396 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
397 entry = entries[index];
398 else
399 entry = &entries[index];
400 return success();
401}
402
403/// Parse and resolve an index into the given entry list.
404template <typename RangeT, typename T>
405static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
406 T &entry, StringRef entryStr) {
407 uint64_t entryIdx;
408 if (failed(reader.parseVarInt(entryIdx)))
409 return failure();
410 return resolveEntry(reader, entries, entryIdx, entry, entryStr);
411}
412
413//===----------------------------------------------------------------------===//
414// StringSectionReader
415//===----------------------------------------------------------------------===//
416
417namespace {
418/// This class is used to read references to the string section from the
419/// bytecode.
420class StringSectionReader {
421public:
422 /// Initialize the string section reader with the given section data.
423 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
424
425 /// Parse a shared string from the string section. The shared string is
426 /// encoded using an index to a corresponding string in the string section.
427 LogicalResult parseString(EncodingReader &reader, StringRef &result) const {
428 return parseEntry(reader, strings, result, "string");
429 }
430
431 /// Parse a shared string from the string section. The shared string is
432 /// encoded using an index to a corresponding string in the string section.
433 /// This variant parses a flag compressed with the index.
434 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
435 bool &flag) const {
436 uint64_t entryIdx;
437 if (failed(reader.parseVarIntWithFlag(entryIdx, flag)))
438 return failure();
439 return parseStringAtIndex(reader, entryIdx, result);
440 }
441
442 /// Parse a shared string from the string section. The shared string is
443 /// encoded using an index to a corresponding string in the string section.
444 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
445 StringRef &result) const {
446 return resolveEntry(reader, strings, index, result, "string");
447 }
448
449private:
450 /// The table of strings referenced within the bytecode file.
451 SmallVector<StringRef> strings;
452};
453} // namespace
454
455LogicalResult StringSectionReader::initialize(Location fileLoc,
456 ArrayRef<uint8_t> sectionData) {
457 EncodingReader stringReader(sectionData, fileLoc);
458
459 // Parse the number of strings in the section.
460 uint64_t numStrings;
461 if (failed(stringReader.parseVarInt(numStrings)))
462 return failure();
463 strings.resize(numStrings);
464
465 // Parse each of the strings. The sizes of the strings are encoded in reverse
466 // order, so that's the order we populate the table.
467 size_t stringDataEndOffset = sectionData.size();
468 for (StringRef &string : llvm::reverse(strings)) {
469 uint64_t stringSize;
470 if (failed(stringReader.parseVarInt(stringSize)))
471 return failure();
472 if (stringDataEndOffset < stringSize) {
473 return stringReader.emitError(
474 "string size exceeds the available data size");
475 }
476
477 // Extract the string from the data, dropping the null character.
478 size_t stringOffset = stringDataEndOffset - stringSize;
479 string = StringRef(
480 reinterpret_cast<const char *>(sectionData.data() + stringOffset),
481 stringSize - 1);
482 stringDataEndOffset = stringOffset;
483 }
484
485 // Check that the only remaining data was for the strings, i.e. the reader
486 // should be at the same offset as the first string.
487 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
488 return stringReader.emitError("unexpected trailing data between the "
489 "offsets for strings and their data");
490 }
491 return success();
492}
493
494//===----------------------------------------------------------------------===//
495// BytecodeDialect
496//===----------------------------------------------------------------------===//
497
498namespace {
499class DialectReader;
500
501/// This struct represents a dialect entry within the bytecode.
502struct BytecodeDialect {
503 /// Load the dialect into the provided context if it hasn't been loaded yet.
504 /// Returns failure if the dialect couldn't be loaded *and* the provided
505 /// context does not allow unregistered dialects. The provided reader is used
506 /// for error emission if necessary.
507 LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
508
509 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
510 /// only be called after `load`.
511 Dialect *getLoadedDialect() const {
512 assert(dialect &&
513 "expected `load` to be invoked before `getLoadedDialect`");
514 return *dialect;
515 }
516
517 /// The loaded dialect entry. This field is std::nullopt if we haven't
518 /// attempted to load, nullptr if we failed to load, otherwise the loaded
519 /// dialect.
520 std::optional<Dialect *> dialect;
521
522 /// The bytecode interface of the dialect, or nullptr if the dialect does not
523 /// implement the bytecode interface. This field should only be checked if the
524 /// `dialect` field is not std::nullopt.
525 const BytecodeDialectInterface *interface = nullptr;
526
527 /// The name of the dialect.
528 StringRef name;
529
530 /// A buffer containing the encoding of the dialect version parsed.
531 ArrayRef<uint8_t> versionBuffer;
532
533 /// Lazy loaded dialect version from the handle above.
534 std::unique_ptr<DialectVersion> loadedVersion;
535};
536
537/// This struct represents an operation name entry within the bytecode.
538struct BytecodeOperationName {
539 BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
540 std::optional<bool> wasRegistered)
541 : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
542
543 /// The loaded operation name, or std::nullopt if it hasn't been processed
544 /// yet.
545 std::optional<OperationName> opName;
546
547 /// The dialect that owns this operation name.
548 BytecodeDialect *dialect;
549
550 /// The name of the operation, without the dialect prefix.
551 StringRef name;
552
553 /// Whether this operation was registered when the bytecode was produced.
554 /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
555 std::optional<bool> wasRegistered;
556};
557} // namespace
558
559/// Parse a single dialect group encoded in the byte stream.
560static LogicalResult parseDialectGrouping(
561 EncodingReader &reader,
562 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
563 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
564 // Parse the dialect and the number of entries in the group.
565 std::unique_ptr<BytecodeDialect> *dialect;
566 if (failed(parseEntry(reader, dialects, dialect, "dialect")))
567 return failure();
568 uint64_t numEntries;
569 if (failed(reader.parseVarInt(numEntries)))
570 return failure();
571
572 for (uint64_t i = 0; i < numEntries; ++i)
573 if (failed(entryCallback(dialect->get())))
574 return failure();
575 return success();
576}
577
578//===----------------------------------------------------------------------===//
579// ResourceSectionReader
580//===----------------------------------------------------------------------===//
581
582namespace {
583/// This class is used to read the resource section from the bytecode.
584class ResourceSectionReader {
585public:
586 /// Initialize the resource section reader with the given section data.
587 LogicalResult
588 initialize(Location fileLoc, const ParserConfig &config,
589 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
590 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
591 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
592 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
593
594 /// Parse a dialect resource handle from the resource section.
595 LogicalResult parseResourceHandle(EncodingReader &reader,
596 AsmDialectResourceHandle &result) const {
597 return parseEntry(reader, dialectResources, result, "resource handle");
598 }
599
600private:
601 /// The table of dialect resources within the bytecode file.
602 SmallVector<AsmDialectResourceHandle> dialectResources;
603 llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
604};
605
606class ParsedResourceEntry : public AsmParsedResourceEntry {
607public:
608 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
609 EncodingReader &reader, StringSectionReader &stringReader,
610 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
611 : key(key), kind(kind), reader(reader), stringReader(stringReader),
612 bufferOwnerRef(bufferOwnerRef) {}
613 ~ParsedResourceEntry() override = default;
614
615 StringRef getKey() const final { return key; }
616
617 InFlightDiagnostic emitError() const final { return reader.emitError(); }
618
619 AsmResourceEntryKind getKind() const final { return kind; }
620
621 FailureOr<bool> parseAsBool() const final {
622 if (kind != AsmResourceEntryKind::Bool)
623 return emitError() << "expected a bool resource entry, but found a "
624 << toString(kind) << " entry instead";
625
626 bool value;
627 if (failed(reader.parseByte(value)))
628 return failure();
629 return value;
630 }
631 FailureOr<std::string> parseAsString() const final {
632 if (kind != AsmResourceEntryKind::String)
633 return emitError() << "expected a string resource entry, but found a "
634 << toString(kind) << " entry instead";
635
636 StringRef string;
637 if (failed(stringReader.parseString(reader, string)))
638 return failure();
639 return string.str();
640 }
641
642 FailureOr<AsmResourceBlob>
643 parseAsBlob(BlobAllocatorFn allocator) const final {
644 if (kind != AsmResourceEntryKind::Blob)
645 return emitError() << "expected a blob resource entry, but found a "
646 << toString(kind) << " entry instead";
647
648 ArrayRef<uint8_t> data;
649 uint64_t alignment;
650 if (failed(reader.parseBlobAndAlignment(data, alignment)))
651 return failure();
652
653 // If we have an extendable reference to the buffer owner, we don't need to
654 // allocate a new buffer for the data, and can use the data directly.
655 if (bufferOwnerRef) {
656 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
657 data.size());
658
659 // Allocate an unmanager buffer which captures a reference to the owner.
660 // For now we just mark this as immutable, but in the future we should
661 // explore marking this as mutable when desired.
663 charData, alignment,
664 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
665 }
666
667 // Allocate memory for the blob using the provided allocator and copy the
668 // data into it.
669 AsmResourceBlob blob = allocator(data.size(), alignment);
670 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
671 blob.isMutable() &&
672 "blob allocator did not return a properly aligned address");
673 memcpy(blob.getMutableData().data(), data.data(), data.size());
674 return blob;
675 }
676
677private:
678 StringRef key;
680 EncodingReader &reader;
681 StringSectionReader &stringReader;
682 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
683};
684} // namespace
685
686template <typename T>
687static LogicalResult
688parseResourceGroup(Location fileLoc, bool allowEmpty,
689 EncodingReader &offsetReader, EncodingReader &resourceReader,
690 StringSectionReader &stringReader, T *handler,
691 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
692 function_ref<StringRef(StringRef)> remapKey = {},
693 function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
694 uint64_t numResources;
695 if (failed(offsetReader.parseVarInt(numResources)))
696 return failure();
697
698 for (uint64_t i = 0; i < numResources; ++i) {
699 StringRef key;
701 uint64_t resourceOffset;
702 ArrayRef<uint8_t> data;
703 if (failed(stringReader.parseString(offsetReader, key)) ||
704 failed(offsetReader.parseVarInt(resourceOffset)) ||
705 failed(offsetReader.parseByte(kind)) ||
706 failed(resourceReader.parseBytes(resourceOffset, data)))
707 return failure();
708
709 // Process the resource key.
710 if ((processKeyFn && failed(processKeyFn(key))))
711 return failure();
712
713 // If the resource data is empty and we allow it, don't error out when
714 // parsing below, just skip it.
715 if (allowEmpty && data.empty())
716 continue;
717
718 // Ignore the entry if we don't have a valid handler.
719 if (!handler)
720 continue;
721
722 // Otherwise, parse the resource value.
723 EncodingReader entryReader(data, fileLoc);
724 key = remapKey(key);
725 ParsedResourceEntry entry(key, kind, entryReader, stringReader,
726 bufferOwnerRef);
727 if (failed(handler->parseResource(entry)))
728 return failure();
729 if (!entryReader.empty()) {
730 return entryReader.emitError(
731 "unexpected trailing bytes in resource entry '", key, "'");
732 }
733 }
734 return success();
735}
736
737LogicalResult ResourceSectionReader::initialize(
738 Location fileLoc, const ParserConfig &config,
739 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
740 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
741 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
742 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
743 EncodingReader resourceReader(sectionData, fileLoc);
744 EncodingReader offsetReader(offsetSectionData, fileLoc);
745
746 // Read the number of external resource providers.
747 uint64_t numExternalResourceGroups;
748 if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
749 return failure();
750
751 // Utility functor that dispatches to `parseResourceGroup`, but implicitly
752 // provides most of the arguments.
753 auto parseGroup = [&](auto *handler, bool allowEmpty = false,
754 function_ref<LogicalResult(StringRef)> keyFn = {}) {
755 auto resolveKey = [&](StringRef key) -> StringRef {
756 auto it = dialectResourceHandleRenamingMap.find(key);
757 if (it == dialectResourceHandleRenamingMap.end())
758 return key;
759 return it->second;
760 };
761
762 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
763 stringReader, handler, bufferOwnerRef, resolveKey,
764 keyFn);
765 };
766
767 // Read the external resources from the bytecode.
768 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
769 StringRef key;
770 if (failed(stringReader.parseString(offsetReader, key)))
771 return failure();
772
773 // Get the handler for these resources.
774 // TODO: Should we require handling external resources in some scenarios?
775 AsmResourceParser *handler = config.getResourceParser(key);
776 if (!handler) {
777 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
778 << "'";
779 }
780
781 if (failed(parseGroup(handler)))
782 return failure();
783 }
784
785 // Read the dialect resources from the bytecode.
786 MLIRContext *ctx = fileLoc->getContext();
787 while (!offsetReader.empty()) {
788 std::unique_ptr<BytecodeDialect> *dialect;
789 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
790 failed((*dialect)->load(dialectReader, ctx)))
791 return failure();
792 Dialect *loadedDialect = (*dialect)->getLoadedDialect();
793 if (!loadedDialect) {
794 return resourceReader.emitError()
795 << "dialect '" << (*dialect)->name << "' is unknown";
796 }
797 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
798 if (!handler) {
799 return resourceReader.emitError()
800 << "unexpected resources for dialect '" << (*dialect)->name << "'";
801 }
802
803 // Ensure that each resource is declared before being processed.
804 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
805 FailureOr<AsmDialectResourceHandle> handle =
806 handler->declareResource(key);
807 if (failed(handle)) {
808 return resourceReader.emitError()
809 << "unknown 'resource' key '" << key << "' for dialect '"
810 << (*dialect)->name << "'";
811 }
812 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
813 dialectResources.push_back(*handle);
814 return success();
815 };
816
817 // Parse the resources for this dialect. We allow empty resources because we
818 // just treat these as declarations.
819 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
820 return failure();
821 }
822
823 return success();
824}
825
826//===----------------------------------------------------------------------===//
827// Attribute/Type Reader
828//===----------------------------------------------------------------------===//
829
830namespace {
831/// This class provides support for reading attribute and type entries from the
832/// bytecode. Attribute and Type entries are read lazily on demand, so we use
833/// this reader to manage when to actually parse them from the bytecode.
834///
835/// The parsing of attributes & types are generally recursive, this can lead to
836/// stack overflows for deeply nested structures, so we track a few extra pieces
837/// of information to avoid this:
838///
839/// - `depth`: The current depth while parsing nested attributes. We defer on
840/// parsing deeply nested attributes to avoid potential stack overflows. The
841/// deferred parsing is achieved by reporting a failure when parsing a nested
842/// attribute/type and registering the index of the encountered attribute/type
843/// in the deferred parsing worklist. Hence, a failure with deffered entry
844/// does not constitute a failure, it also requires that folks return on
845/// first failure rather than attempting additional parses.
846/// - `deferredWorklist`: A list of attribute/type indices that we could not
847/// parse due to hitting the depth limit. The worklist is used to capture the
848/// indices of attributes/types that need to be parsed/reparsed when we hit
849/// the depth limit. This enables moving the tracking of what needs to be
850/// parsed to the heap.
851class AttrTypeReader {
852 /// This class represents a single attribute or type entry.
853 template <typename T>
854 struct Entry {
855 /// The entry, or null if it hasn't been resolved yet.
856 T entry = {};
857 /// The parent dialect of this entry.
858 BytecodeDialect *dialect = nullptr;
859 /// A flag indicating if the entry was encoded using a custom encoding,
860 /// instead of using the textual assembly format.
861 bool hasCustomEncoding = false;
862 /// The raw data of this entry in the bytecode.
863 ArrayRef<uint8_t> data;
864 };
865 using AttrEntry = Entry<Attribute>;
866 using TypeEntry = Entry<Type>;
867
868public:
869 AttrTypeReader(const StringSectionReader &stringReader,
870 const ResourceSectionReader &resourceReader,
871 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
872 uint64_t &bytecodeVersion, Location fileLoc,
873 const ParserConfig &config)
874 : stringReader(stringReader), resourceReader(resourceReader),
875 dialectsMap(dialectsMap), fileLoc(fileLoc),
876 bytecodeVersion(bytecodeVersion), parserConfig(config) {}
877
878 /// Initialize the attribute and type information within the reader.
879 LogicalResult
880 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
881 ArrayRef<uint8_t> sectionData,
882 ArrayRef<uint8_t> offsetSectionData);
883
884 LogicalResult readAttribute(uint64_t index, Attribute &result,
885 uint64_t depth = 0) {
886 return readEntry(attributes, index, result, "attribute", depth);
887 }
888
889 LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) {
890 return readEntry(types, index, result, "type", depth);
891 }
892
893 /// Resolve the attribute or type at the given index. Returns nullptr on
894 /// failure.
895 Attribute resolveAttribute(size_t index, uint64_t depth = 0) {
896 return resolveEntry(attributes, index, "Attribute", depth);
897 }
898 Type resolveType(size_t index, uint64_t depth = 0) {
899 return resolveEntry(types, index, "Type", depth);
900 }
901
902 Attribute getAttributeOrSentinel(size_t index) {
903 if (index >= attributes.size())
904 return nullptr;
905 return attributes[index].entry;
906 }
907 Type getTypeOrSentinel(size_t index) {
908 if (index >= types.size())
909 return nullptr;
910 return types[index].entry;
911 }
912
913 /// Parse a reference to an attribute or type using the given reader.
914 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
915 uint64_t attrIdx;
916 if (failed(reader.parseVarInt(attrIdx)))
917 return failure();
918 result = resolveAttribute(attrIdx);
919 return success(!!result);
920 }
921 LogicalResult parseOptionalAttribute(EncodingReader &reader,
922 Attribute &result) {
923 uint64_t attrIdx;
924 bool flag;
925 if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
926 return failure();
927 if (!flag)
928 return success();
929 result = resolveAttribute(attrIdx);
930 return success(!!result);
931 }
932
933 LogicalResult parseType(EncodingReader &reader, Type &result) {
934 uint64_t typeIdx;
935 if (failed(reader.parseVarInt(typeIdx)))
936 return failure();
937 result = resolveType(typeIdx);
938 return success(!!result);
939 }
940
941 template <typename T>
942 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
943 Attribute baseResult;
944 if (failed(parseAttribute(reader, baseResult)))
945 return failure();
946 if ((result = dyn_cast<T>(baseResult)))
947 return success();
948 return reader.emitError("expected attribute of type: ",
949 llvm::getTypeName<T>(), ", but got: ", baseResult);
950 }
951
952 /// Add an index to the deferred worklist for re-parsing.
953 void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
954
955private:
956 /// Resolve the given entry at `index`.
957 template <typename T>
958 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
959 StringRef entryType, uint64_t depth = 0);
960
961 /// Read the entry at the given index, returning failure if the entry is not
962 /// yet resolved.
963 template <typename T>
964 LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
965 T &result, StringRef entryType, uint64_t depth);
966
967 /// Parse an entry using the given reader that was encoded using a custom
968 /// bytecode format.
969 template <typename T>
970 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
971 StringRef entryType, uint64_t index,
972 uint64_t depth);
973
974 /// Parse an entry using the given reader that was encoded using the textual
975 /// assembly format.
976 template <typename T>
977 LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
978 StringRef entryType);
979
980 /// The string section reader used to resolve string references when parsing
981 /// custom encoded attribute/type entries.
982 const StringSectionReader &stringReader;
983
984 /// The resource section reader used to resolve resource references when
985 /// parsing custom encoded attribute/type entries.
986 const ResourceSectionReader &resourceReader;
987
988 /// The map of the loaded dialects used to retrieve dialect information, such
989 /// as the dialect version.
990 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
991
992 /// The set of attribute and type entries.
993 SmallVector<AttrEntry> attributes;
994 SmallVector<TypeEntry> types;
995
996 /// A location used for error emission.
997 Location fileLoc;
998
999 /// Current bytecode version being used.
1000 uint64_t &bytecodeVersion;
1001
1002 /// Reference to the parser configuration.
1003 const ParserConfig &parserConfig;
1004
1005 /// Worklist for deferred attribute/type parsing. This is used to handle
1006 /// deeply nested structures like CallSiteLoc iteratively.
1007 std::vector<uint64_t> deferredWorklist;
1008};
1009
1010class DialectReader : public DialectBytecodeReader {
1011public:
1012 DialectReader(AttrTypeReader &attrTypeReader,
1013 const StringSectionReader &stringReader,
1014 const ResourceSectionReader &resourceReader,
1015 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
1016 EncodingReader &reader, uint64_t &bytecodeVersion,
1017 uint64_t depth = 0)
1018 : attrTypeReader(attrTypeReader), stringReader(stringReader),
1019 resourceReader(resourceReader), dialectsMap(dialectsMap),
1020 reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
1021
1022 InFlightDiagnostic emitError(const Twine &msg) const override {
1023 return reader.emitError(msg);
1024 }
1025
1026 FailureOr<const DialectVersion *>
1027 getDialectVersion(StringRef dialectName) const override {
1028 // First check if the dialect is available in the map.
1029 auto dialectEntry = dialectsMap.find(dialectName);
1030 if (dialectEntry == dialectsMap.end())
1031 return failure();
1032 // If the dialect was found, try to load it. This will trigger reading the
1033 // bytecode version from the version buffer if it wasn't already processed.
1034 // Return failure if either of those two actions could not be completed.
1035 if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
1036 dialectEntry->getValue()->loadedVersion == nullptr)
1037 return failure();
1038 return dialectEntry->getValue()->loadedVersion.get();
1039 }
1040
1041 MLIRContext *getContext() const override { return getLoc().getContext(); }
1042
1043 uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
1044
1045 DialectReader withEncodingReader(EncodingReader &encReader) const {
1046 return DialectReader(attrTypeReader, stringReader, resourceReader,
1047 dialectsMap, encReader, bytecodeVersion);
1048 }
1049
1050 Location getLoc() const { return reader.getLoc(); }
1051
1052 //===--------------------------------------------------------------------===//
1053 // IR
1054 //===--------------------------------------------------------------------===//
1055
1056 /// The maximum depth to eagerly parse nested attributes/types before
1057 /// deferring.
1058 static constexpr uint64_t maxAttrTypeDepth = 5;
1059
1060 LogicalResult readAttribute(Attribute &result) override {
1061 uint64_t index;
1062 if (failed(reader.parseVarInt(index)))
1063 return failure();
1064 if (depth > maxAttrTypeDepth) {
1065 if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
1066 result = attr;
1067 return success();
1068 }
1069 attrTypeReader.addDeferredParsing(index);
1070 return failure();
1071 }
1072 return attrTypeReader.readAttribute(index, result, depth + 1);
1073 }
1074 LogicalResult readOptionalAttribute(Attribute &result) override {
1075 return attrTypeReader.parseOptionalAttribute(reader, result);
1076 }
1077 LogicalResult readType(Type &result) override {
1078 uint64_t index;
1079 if (failed(reader.parseVarInt(index)))
1080 return failure();
1081 if (depth > maxAttrTypeDepth) {
1082 if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
1083 result = type;
1084 return success();
1085 }
1086 attrTypeReader.addDeferredParsing(index);
1087 return failure();
1088 }
1089 return attrTypeReader.readType(index, result, depth + 1);
1090 }
1091
1092 FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
1093 AsmDialectResourceHandle handle;
1094 if (failed(resourceReader.parseResourceHandle(reader, handle)))
1095 return failure();
1096 return handle;
1097 }
1098
1099 //===--------------------------------------------------------------------===//
1100 // Primitives
1101 //===--------------------------------------------------------------------===//
1102
1103 LogicalResult readVarInt(uint64_t &result) override {
1104 return reader.parseVarInt(result);
1105 }
1106
1107 LogicalResult readSignedVarInt(int64_t &result) override {
1108 uint64_t unsignedResult;
1109 if (failed(reader.parseSignedVarInt(unsignedResult)))
1110 return failure();
1111 result = static_cast<int64_t>(unsignedResult);
1112 return success();
1113 }
1114
1115 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
1116 // Small values are encoded using a single byte.
1117 if (bitWidth <= 8) {
1118 uint8_t value;
1119 if (failed(reader.parseByte(value)))
1120 return failure();
1121 return APInt(bitWidth, value);
1122 }
1123
1124 // Large values up to 64 bits are encoded using a single varint.
1125 if (bitWidth <= 64) {
1126 uint64_t value;
1127 if (failed(reader.parseSignedVarInt(value)))
1128 return failure();
1129 return APInt(bitWidth, value);
1130 }
1131
1132 // Otherwise, for really big values we encode the array of active words in
1133 // the value.
1134 uint64_t numActiveWords;
1135 if (failed(reader.parseVarInt(numActiveWords)))
1136 return failure();
1137 SmallVector<uint64_t, 4> words(numActiveWords);
1138 for (uint64_t i = 0; i < numActiveWords; ++i)
1139 if (failed(reader.parseSignedVarInt(words[i])))
1140 return failure();
1141 return APInt(bitWidth, words);
1142 }
1143
1144 FailureOr<APFloat>
1145 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
1146 FailureOr<APInt> intVal =
1147 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
1148 if (failed(intVal))
1149 return failure();
1150 return APFloat(semantics, *intVal);
1151 }
1152
1153 LogicalResult readString(StringRef &result) override {
1154 return stringReader.parseString(reader, result);
1155 }
1156
1157 LogicalResult readBlob(ArrayRef<char> &result) override {
1158 uint64_t dataSize;
1159 ArrayRef<uint8_t> data;
1160 if (failed(reader.parseVarInt(dataSize)) ||
1161 failed(reader.parseBytes(dataSize, data)))
1162 return failure();
1163 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()),
1164 data.size());
1165 return success();
1166 }
1167
1168 LogicalResult readBool(bool &result) override {
1169 return reader.parseByte(result);
1170 }
1171
1172private:
1173 AttrTypeReader &attrTypeReader;
1174 const StringSectionReader &stringReader;
1175 const ResourceSectionReader &resourceReader;
1176 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1177 EncodingReader &reader;
1178 uint64_t &bytecodeVersion;
1179 uint64_t depth;
1180};
1181
1182/// Wraps the properties section and handles reading properties out of it.
1183class PropertiesSectionReader {
1184public:
1185 /// Initialize the properties section reader with the given section data.
1186 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1187 if (sectionData.empty())
1188 return success();
1189 EncodingReader propReader(sectionData, fileLoc);
1190 uint64_t count;
1191 if (failed(propReader.parseVarInt(count)))
1192 return failure();
1193 // Parse the raw properties buffer.
1194 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
1195 return failure();
1196
1197 EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1198 offsetTable.reserve(count);
1199 for (auto idx : llvm::seq<int64_t>(0, count)) {
1200 (void)idx;
1201 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
1202 ArrayRef<uint8_t> rawProperties;
1203 uint64_t dataSize;
1204 if (failed(offsetsReader.parseVarInt(dataSize)) ||
1205 failed(offsetsReader.parseBytes(dataSize, rawProperties)))
1206 return failure();
1207 }
1208 if (!offsetsReader.empty())
1209 return offsetsReader.emitError()
1210 << "Broken properties section: didn't exhaust the offsets table";
1211 return success();
1212 }
1213
1214 LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1215 OperationName *opName, OperationState &opState) const {
1216 uint64_t propertiesIdx;
1217 if (failed(dialectReader.readVarInt(propertiesIdx)))
1218 return failure();
1219 if (propertiesIdx >= offsetTable.size())
1220 return dialectReader.emitError("Properties idx out-of-bound for ")
1221 << opName->getStringRef();
1222 size_t propertiesOffset = offsetTable[propertiesIdx];
1223 if (propertiesIdx >= propertiesBuffers.size())
1224 return dialectReader.emitError("Properties offset out-of-bound for ")
1225 << opName->getStringRef();
1226
1227 // Acquire the sub-buffer that represent the requested properties.
1228 ArrayRef<char> rawProperties;
1229 {
1230 // "Seek" to the requested offset by getting a new reader with the right
1231 // sub-buffer.
1232 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
1233 fileLoc);
1234 // Properties are stored as a sequence of {size + raw_data}.
1235 if (failed(
1236 dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
1237 return failure();
1238 }
1239 // Setup a new reader to read from the `rawProperties` sub-buffer.
1240 EncodingReader reader(
1241 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1242 DialectReader propReader = dialectReader.withEncodingReader(reader);
1243
1244 auto *iface = opName->getInterface<BytecodeOpInterface>();
1245 if (iface)
1246 return iface->readProperties(propReader, opState);
1247 if (opName->isRegistered())
1248 return propReader.emitError(
1249 "has properties but missing BytecodeOpInterface for ")
1250 << opName->getStringRef();
1251 // Unregistered op are storing properties as an attribute.
1252 return propReader.readAttribute(opState.propertiesAttr);
1253 }
1254
1255private:
1256 /// The properties buffer referenced within the bytecode file.
1257 ArrayRef<uint8_t> propertiesBuffers;
1258
1259 /// Table of offset in the buffer above.
1260 SmallVector<int64_t> offsetTable;
1261};
1262} // namespace
1263
1264LogicalResult AttrTypeReader::initialize(
1265 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1266 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1267 EncodingReader offsetReader(offsetSectionData, fileLoc);
1268
1269 // Parse the number of attribute and type entries.
1270 uint64_t numAttributes, numTypes;
1271 if (failed(offsetReader.parseVarInt(numAttributes)) ||
1272 failed(offsetReader.parseVarInt(numTypes)))
1273 return failure();
1274 attributes.resize(numAttributes);
1275 types.resize(numTypes);
1276
1277 // A functor used to accumulate the offsets for the entries in the given
1278 // range.
1279 uint64_t currentOffset = 0;
1280 auto parseEntries = [&](auto &&range) {
1281 size_t currentIndex = 0, endIndex = range.size();
1282
1283 // Parse an individual entry.
1284 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1285 auto &entry = range[currentIndex++];
1286
1287 uint64_t entrySize;
1288 if (failed(offsetReader.parseVarIntWithFlag(entrySize,
1289 entry.hasCustomEncoding)))
1290 return failure();
1291
1292 // Verify that the offset is actually valid.
1293 if (currentOffset + entrySize > sectionData.size()) {
1294 return offsetReader.emitError(
1295 "Attribute or Type entry offset points past the end of section");
1296 }
1297
1298 entry.data = sectionData.slice(currentOffset, entrySize);
1299 entry.dialect = dialect;
1300 currentOffset += entrySize;
1301 return success();
1302 };
1303 while (currentIndex != endIndex)
1304 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
1305 return failure();
1306 return success();
1307 };
1308
1309 // Process each of the attributes, and then the types.
1310 if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
1311 return failure();
1312
1313 // Ensure that we read everything from the section.
1314 if (!offsetReader.empty()) {
1315 return offsetReader.emitError(
1316 "unexpected trailing data in the Attribute/Type offset section");
1317 }
1318
1319 return success();
1320}
1321
1322template <typename T>
1323T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
1324 uint64_t index, StringRef entryType,
1325 uint64_t depth) {
1326 if (index >= entries.size()) {
1327 emitError(fileLoc) << "invalid " << entryType << " index: " << index;
1328 return {};
1329 }
1330
1331 // Fast path: Try direct parsing without worklist overhead. This handles the
1332 // common case where there are no deferred dependencies.
1333 assert(deferredWorklist.empty());
1334 T result;
1335 if (succeeded(readEntry(entries, index, result, entryType, depth))) {
1336 assert(deferredWorklist.empty());
1337 return result;
1338 }
1339 if (deferredWorklist.empty()) {
1340 // Failed with no deferred entries is error.
1341 return T();
1342 }
1343
1344 // Slow path: Use worklist to handle deferred dependencies. Use a deque to
1345 // iteratively resolve entries with dependencies.
1346 // - Pop from front to process
1347 // - Push new dependencies to front (depth-first)
1348 // - Move failed entries to back (retry after dependencies)
1349 std::deque<size_t> worklist;
1350 llvm::DenseSet<size_t> inWorklist;
1351
1352 // Add the original index and any dependencies from the fast path attempt.
1353 worklist.push_back(index);
1354 inWorklist.insert(index);
1355 for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1356 if (inWorklist.insert(idx).second)
1357 worklist.push_front(idx);
1358 }
1359
1360 while (!worklist.empty()) {
1361 size_t currentIndex = worklist.front();
1362 worklist.pop_front();
1363
1364 // Clear the deferred worklist before parsing to capture any new entries.
1365 deferredWorklist.clear();
1366
1367 T result;
1368 if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) {
1369 inWorklist.erase(currentIndex);
1370 continue;
1371 }
1372
1373 if (deferredWorklist.empty()) {
1374 // Parsing failed with no deferred entries which implies an error.
1375 return T();
1376 }
1377
1378 // Move this entry to the back to retry after dependencies.
1379 worklist.push_back(currentIndex);
1380
1381 // Add dependencies to the front (in reverse so they maintain order).
1382 for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1383 if (inWorklist.insert(idx).second)
1384 worklist.push_front(idx);
1385 }
1386 deferredWorklist.clear();
1387 }
1388 return entries[index].entry;
1389}
1390
1391template <typename T>
1392LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
1393 uint64_t index, T &result,
1394 StringRef entryType, uint64_t depth) {
1395 if (index >= entries.size())
1396 return emitError(fileLoc) << "invalid " << entryType << " index: " << index;
1397
1398 // If the entry has already been resolved, return it.
1399 Entry<T> &entry = entries[index];
1400 if (entry.entry) {
1401 result = entry.entry;
1402 return success();
1403 }
1404
1405 // If the entry hasn't been resolved, try to parse it.
1406 EncodingReader reader(entry.data, fileLoc);
1407 LogicalResult parseResult =
1408 entry.hasCustomEncoding
1409 ? parseCustomEntry(entry, reader, entryType, index, depth)
1410 : parseAsmEntry(entry.entry, reader, entryType);
1411 if (failed(parseResult))
1412 return failure();
1413
1414 if (!reader.empty())
1415 return reader.emitError("unexpected trailing bytes after " + entryType +
1416 " entry");
1417
1418 result = entry.entry;
1419 return success();
1420}
1421
1422template <typename T>
1423LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1424 EncodingReader &reader,
1425 StringRef entryType,
1426 uint64_t index, uint64_t depth) {
1427 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1428 reader, bytecodeVersion, depth);
1429 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
1430 return failure();
1431
1432 if constexpr (std::is_same_v<T, Type>) {
1433 // Try parsing with callbacks first if available.
1434 for (const auto &callback :
1435 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
1436 if (failed(
1437 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1438 return failure();
1439 // Early return if parsing was successful.
1440 if (!!entry.entry)
1441 return success();
1442
1443 // Reset the reader if we failed to parse, so we can fall through the
1444 // other parsing functions.
1445 reader = EncodingReader(entry.data, reader.getLoc());
1446 }
1447 } else {
1448 // Try parsing with callbacks first if available.
1449 for (const auto &callback :
1451 if (failed(
1452 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1453 return failure();
1454 // Early return if parsing was successful.
1455 if (!!entry.entry)
1456 return success();
1457
1458 // Reset the reader if we failed to parse, so we can fall through the
1459 // other parsing functions.
1460 reader = EncodingReader(entry.data, reader.getLoc());
1461 }
1462 }
1463
1464 // Ensure that the dialect implements the bytecode interface.
1465 if (!entry.dialect->interface) {
1466 return reader.emitError("dialect '", entry.dialect->name,
1467 "' does not implement the bytecode interface");
1468 }
1469
1470 if constexpr (std::is_same_v<T, Type>)
1471 entry.entry = entry.dialect->interface->readType(dialectReader);
1472 else
1473 entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1474
1475 return success(!!entry.entry);
1476}
1477
1478template <typename T>
1479LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1480 StringRef entryType) {
1481 StringRef asmStr;
1482 if (failed(reader.parseNullTerminatedString(asmStr)))
1483 return failure();
1484
1485 // Invoke the MLIR assembly parser to parse the entry text.
1486 size_t numRead = 0;
1487 MLIRContext *context = fileLoc->getContext();
1488 if constexpr (std::is_same_v<T, Type>)
1489 result =
1490 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1491 else
1492 result = ::parseAttribute(asmStr, context, Type(), &numRead,
1493 /*isKnownNullTerminated=*/true);
1494 if (!result)
1495 return failure();
1496
1497 // Ensure there weren't dangling characters after the entry.
1498 if (numRead != asmStr.size()) {
1499 return reader.emitError("trailing characters found after ", entryType,
1500 " assembly format: ", asmStr.drop_front(numRead));
1501 }
1502 return success();
1503}
1504
1505//===----------------------------------------------------------------------===//
1506// Bytecode Reader
1507//===----------------------------------------------------------------------===//
1508
1509/// This class is used to read a bytecode buffer and translate it into MLIR.
1511 struct RegionReadState;
1512 using LazyLoadableOpsInfo =
1513 std::list<std::pair<Operation *, RegionReadState>>;
1514 using LazyLoadableOpsMap =
1516
1517public:
1518 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
1519 llvm::MemoryBufferRef buffer,
1520 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1521 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1522 attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1523 fileLoc, config),
1524 // Use the builtin unrealized conversion cast operation to represent
1525 // forward references to values that aren't yet defined.
1526 forwardRefOpState(UnknownLoc::get(config.getContext()),
1527 "builtin.unrealized_conversion_cast", ValueRange(),
1528 NoneType::get(config.getContext())),
1529 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1530
1531 /// Read the bytecode defined within `buffer` into the given block.
1532 LogicalResult read(Block *block,
1533 llvm::function_ref<bool(Operation *)> lazyOps);
1534
1535 /// Return the number of ops that haven't been materialized yet.
1536 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
1537
1538 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
1539
1540 /// Materialize the provided operation, invoke the lazyOpsCallback on every
1541 /// newly found lazy operation.
1542 LogicalResult
1544 llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1545 this->lazyOpsCallback = lazyOpsCallback;
1546 auto resetlazyOpsCallback =
1547 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1548 auto it = lazyLoadableOpsMap.find(op);
1549 assert(it != lazyLoadableOpsMap.end() &&
1550 "materialize called on non-materializable op");
1551 return materialize(it);
1552 }
1553
1554 /// Materialize all operations.
1555 LogicalResult materializeAll() {
1556 while (!lazyLoadableOpsMap.empty()) {
1557 if (failed(materialize(lazyLoadableOpsMap.begin())))
1558 return failure();
1559 }
1560 return success();
1561 }
1562
1563 /// Finalize the lazy-loading by calling back with every op that hasn't been
1564 /// materialized to let the client decide if the op should be deleted or
1565 /// materialized. The op is materialized if the callback returns true, deleted
1566 /// otherwise.
1567 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
1568 while (!lazyLoadableOps.empty()) {
1569 Operation *op = lazyLoadableOps.begin()->first;
1570 if (shouldMaterialize(op)) {
1571 if (failed(materialize(lazyLoadableOpsMap.find(op))))
1572 return failure();
1573 continue;
1574 }
1575 op->dropAllReferences();
1576 op->erase();
1577 lazyLoadableOps.pop_front();
1578 lazyLoadableOpsMap.erase(op);
1579 }
1580 return success();
1581 }
1582
1583private:
1584 LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
1585 assert(it != lazyLoadableOpsMap.end() &&
1586 "materialize called on non-materializable op");
1587 valueScopes.emplace_back();
1588 std::vector<RegionReadState> regionStack;
1589 regionStack.push_back(std::move(it->getSecond()->second));
1590 lazyLoadableOps.erase(it->getSecond());
1591 lazyLoadableOpsMap.erase(it);
1592
1593 while (!regionStack.empty())
1594 if (failed(parseRegions(regionStack, regionStack.back())))
1595 return failure();
1596 return success();
1597 }
1598
1599 LogicalResult checkSectionAlignment(
1600 unsigned alignment,
1601 function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
1602 // Check that the bytecode buffer meets the requested section alignment.
1603 //
1604 // If it does not, the virtual address of the item in the section will
1605 // not be aligned to the requested alignment.
1606 //
1607 // The typical case where this is necessary is the resource blob
1608 // optimization in `parseAsBlob` where we reference the weights from the
1609 // provided buffer instead of copying them to a new allocation.
1610 const bool isGloballyAligned =
1611 ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
1612
1613 if (!isGloballyAligned)
1614 return emitError("expected section alignment ")
1615 << alignment << " but bytecode buffer 0x"
1616 << Twine::utohexstr((uint64_t)buffer.getBufferStart())
1617 << " is not aligned";
1618
1619 return success();
1620 };
1621
1622 /// Return the context for this config.
1623 MLIRContext *getContext() const { return config.getContext(); }
1624
1625 /// Parse the bytecode version.
1626 LogicalResult parseVersion(EncodingReader &reader);
1627
1628 //===--------------------------------------------------------------------===//
1629 // Dialect Section
1630
1631 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1632
1633 /// Parse an operation name reference using the given reader, and set the
1634 /// `wasRegistered` flag that indicates if the bytecode was produced by a
1635 /// context where opName was registered.
1636 FailureOr<OperationName> parseOpName(EncodingReader &reader,
1637 std::optional<bool> &wasRegistered);
1638
1639 //===--------------------------------------------------------------------===//
1640 // Attribute/Type Section
1641
1642 /// Parse an attribute or type using the given reader.
1643 template <typename T>
1644 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
1645 return attrTypeReader.parseAttribute(reader, result);
1646 }
1647 LogicalResult parseType(EncodingReader &reader, Type &result) {
1648 return attrTypeReader.parseType(reader, result);
1649 }
1650
1651 //===--------------------------------------------------------------------===//
1652 // Resource Section
1653
1654 LogicalResult
1655 parseResourceSection(EncodingReader &reader,
1656 std::optional<ArrayRef<uint8_t>> resourceData,
1657 std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1658
1659 //===--------------------------------------------------------------------===//
1660 // IR Section
1661
1662 /// This struct represents the current read state of a range of regions. This
1663 /// struct is used to enable iterative parsing of regions.
1664 struct RegionReadState {
1665 RegionReadState(Operation *op, EncodingReader *reader,
1666 bool isIsolatedFromAbove)
1667 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1668 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1669 bool isIsolatedFromAbove)
1670 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1671 isIsolatedFromAbove(isIsolatedFromAbove) {}
1672
1673 /// The current regions being read.
1674 MutableArrayRef<Region>::iterator curRegion, endRegion;
1675 /// This is the reader to use for this region, this pointer is pointing to
1676 /// the parent region reader unless the current region is IsolatedFromAbove,
1677 /// in which case the pointer is pointing to the `owningReader` which is a
1678 /// section dedicated to the current region.
1679 EncodingReader *reader;
1680 std::unique_ptr<EncodingReader> owningReader;
1681
1682 /// The number of values defined immediately within this region.
1683 unsigned numValues = 0;
1684
1685 /// The current blocks of the region being read.
1686 SmallVector<Block *> curBlocks;
1687 Region::iterator curBlock = {};
1688
1689 /// The number of operations remaining to be read from the current block
1690 /// being read.
1691 uint64_t numOpsRemaining = 0;
1692
1693 /// A flag indicating if the regions being read are isolated from above.
1694 bool isIsolatedFromAbove = false;
1695 };
1696
1697 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
1698 LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
1699 RegionReadState &readState);
1700 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1701 RegionReadState &readState,
1702 bool &isIsolatedFromAbove);
1703
1704 LogicalResult parseRegion(RegionReadState &readState);
1705 LogicalResult parseBlockHeader(EncodingReader &reader,
1706 RegionReadState &readState);
1707 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
1708
1709 //===--------------------------------------------------------------------===//
1710 // Value Processing
1711
1712 /// Parse an operand reference using the given reader. Returns nullptr in the
1713 /// case of failure.
1714 Value parseOperand(EncodingReader &reader);
1715
1716 /// Sequentially define the given value range.
1717 LogicalResult defineValues(EncodingReader &reader, ValueRange values);
1718
1719 /// Create a value to use for a forward reference.
1720 Value createForwardRef();
1721
1722 //===--------------------------------------------------------------------===//
1723 // Use-list order helpers
1724
1725 /// This struct is a simple storage that contains information required to
1726 /// reorder the use-list of a value with respect to the pre-order traversal
1727 /// ordering.
1728 struct UseListOrderStorage {
1729 UseListOrderStorage(bool isIndexPairEncoding,
1730 SmallVector<unsigned, 4> &&indices)
1731 : indices(std::move(indices)),
1732 isIndexPairEncoding(isIndexPairEncoding) {};
1733 /// The vector containing the information required to reorder the
1734 /// use-list of a value.
1735 SmallVector<unsigned, 4> indices;
1736
1737 /// Whether indices represent a pair of type `(src, dst)` or it is a direct
1738 /// indexing, such as `dst = order[src]`.
1739 bool isIndexPairEncoding;
1740 };
1741
1742 /// Parse use-list order from bytecode for a range of values if available. The
1743 /// range is expected to be either a block argument or an op result range. On
1744 /// success, return a map of the position in the range and the use-list order
1745 /// encoding. The function assumes to know the size of the range it is
1746 /// processing.
1747 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1748 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1749 uint64_t rangeSize);
1750
1751 /// Shuffle the use-chain according to the order parsed.
1752 LogicalResult sortUseListOrder(Value value);
1753
1754 /// Recursively visit all the values defined within topLevelOp and sort the
1755 /// use-list orders according to the indices parsed.
1756 LogicalResult processUseLists(Operation *topLevelOp);
1757
1758 //===--------------------------------------------------------------------===//
1759 // Fields
1760
1761 /// This class represents a single value scope, in which a value scope is
1762 /// delimited by isolated from above regions.
1763 struct ValueScope {
1764 /// Push a new region state onto this scope, reserving enough values for
1765 /// those defined within the current region of the provided state.
1766 void push(RegionReadState &readState) {
1767 nextValueIDs.push_back(values.size());
1768 values.resize(values.size() + readState.numValues);
1769 }
1770
1771 /// Pop the values defined for the current region within the provided region
1772 /// state.
1773 void pop(RegionReadState &readState) {
1774 values.resize(values.size() - readState.numValues);
1775 nextValueIDs.pop_back();
1776 }
1777
1778 /// The set of values defined in this scope.
1779 std::vector<Value> values;
1780
1781 /// The ID for the next defined value for each region current being
1782 /// processed in this scope.
1783 SmallVector<unsigned, 4> nextValueIDs;
1784 };
1785
1786 /// The configuration of the parser.
1787 const ParserConfig &config;
1788
1789 /// A location to use when emitting errors.
1790 Location fileLoc;
1791
1792 /// Flag that indicates if lazyloading is enabled.
1793 bool lazyLoading;
1794
1795 /// Keep track of operations that have been lazy loaded (their regions haven't
1796 /// been materialized), along with the `RegionReadState` that allows to
1797 /// lazy-load the regions nested under the operation.
1798 LazyLoadableOpsInfo lazyLoadableOps;
1799 LazyLoadableOpsMap lazyLoadableOpsMap;
1800 llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1801
1802 /// The reader used to process attribute and types within the bytecode.
1803 AttrTypeReader attrTypeReader;
1804
1805 /// The version of the bytecode being read.
1806 uint64_t version = 0;
1807
1808 /// The producer of the bytecode being read.
1809 StringRef producer;
1810
1811 /// The table of IR units referenced within the bytecode file.
1812 SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1813 llvm::StringMap<BytecodeDialect *> dialectsMap;
1814 SmallVector<BytecodeOperationName> opNames;
1815
1816 /// The reader used to process resources within the bytecode.
1817 ResourceSectionReader resourceReader;
1818
1819 /// Worklist of values with custom use-list orders to process before the end
1820 /// of the parsing.
1821 DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1822
1823 /// The table of strings referenced within the bytecode file.
1824 StringSectionReader stringReader;
1825
1826 /// The table of properties referenced by the operation in the bytecode file.
1827 PropertiesSectionReader propertiesReader;
1828
1829 /// The current set of available IR value scopes.
1830 std::vector<ValueScope> valueScopes;
1831
1832 /// The global pre-order operation ordering.
1834
1835 /// A block containing the set of operations defined to create forward
1836 /// references.
1837 Block forwardRefOps;
1838
1839 /// A block containing previously created, and no longer used, forward
1840 /// reference operations.
1841 Block openForwardRefOps;
1842
1843 /// An operation state used when instantiating forward references.
1844 OperationState forwardRefOpState;
1845
1846 /// Reference to the input buffer.
1847 llvm::MemoryBufferRef buffer;
1848
1849 /// The optional owning source manager, which when present may be used to
1850 /// extend the lifetime of the input buffer.
1851 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1852};
1853
1855 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1856 EncodingReader reader(buffer.getBuffer(), fileLoc);
1857 this->lazyOpsCallback = lazyOpsCallback;
1858 auto resetlazyOpsCallback =
1859 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1860
1861 // Skip over the bytecode header, this should have already been checked.
1862 if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
1863 return failure();
1864 // Parse the bytecode version and producer.
1865 if (failed(parseVersion(reader)) ||
1866 failed(reader.parseNullTerminatedString(producer)))
1867 return failure();
1868
1869 // Add a diagnostic handler that attaches a note that includes the original
1870 // producer of the bytecode.
1871 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
1872 diag.attachNote() << "in bytecode version " << version
1873 << " produced by: " << producer;
1874 return failure();
1875 });
1876
1877 const auto checkSectionAlignment = [&](unsigned alignment) {
1878 return this->checkSectionAlignment(
1879 alignment, [&](const auto &msg) { return reader.emitError(msg); });
1880 };
1881
1882 // Parse the raw data for each of the top-level sections of the bytecode.
1883 std::optional<ArrayRef<uint8_t>>
1884 sectionDatas[bytecode::Section::kNumSections];
1885 while (!reader.empty()) {
1886 // Read the next section from the bytecode.
1887 bytecode::Section::ID sectionID;
1888 ArrayRef<uint8_t> sectionData;
1889 if (failed(
1890 reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
1891 return failure();
1892
1893 // Check for duplicate sections, we only expect one instance of each.
1894 if (sectionDatas[sectionID]) {
1895 return reader.emitError("duplicate top-level section: ",
1896 ::toString(sectionID));
1897 }
1898 sectionDatas[sectionID] = sectionData;
1899 }
1900 // Check that all of the required sections were found.
1901 for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
1902 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
1903 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
1904 return reader.emitError("missing data for top-level section: ",
1905 ::toString(sectionID));
1906 }
1907 }
1908
1909 // Process the string section first.
1910 if (failed(stringReader.initialize(
1911 fileLoc, *sectionDatas[bytecode::Section::kString])))
1912 return failure();
1913
1914 // Process the properties section.
1915 if (sectionDatas[bytecode::Section::kProperties] &&
1916 failed(propertiesReader.initialize(
1917 fileLoc, *sectionDatas[bytecode::Section::kProperties])))
1918 return failure();
1919
1920 // Process the dialect section.
1921 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
1922 return failure();
1923
1924 // Process the resource section if present.
1925 if (failed(parseResourceSection(
1926 reader, sectionDatas[bytecode::Section::kResource],
1927 sectionDatas[bytecode::Section::kResourceOffset])))
1928 return failure();
1929
1930 // Process the attribute and type section.
1931 if (failed(attrTypeReader.initialize(
1932 dialects, *sectionDatas[bytecode::Section::kAttrType],
1933 *sectionDatas[bytecode::Section::kAttrTypeOffset])))
1934 return failure();
1935
1936 // Finally, process the IR section.
1937 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
1938}
1939
1940LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1941 if (failed(reader.parseVarInt(version)))
1942 return failure();
1943
1944 // Validate the bytecode version.
1945 uint64_t currentVersion = bytecode::kVersion;
1946 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
1947 if (version < minSupportedVersion) {
1948 return reader.emitError("bytecode version ", version,
1949 " is older than the current version of ",
1950 currentVersion, ", and upgrade is not supported");
1951 }
1952 if (version > currentVersion) {
1953 return reader.emitError("bytecode version ", version,
1954 " is newer than the current version ",
1955 currentVersion);
1956 }
1957 // Override any request to lazy-load if the bytecode version is too old.
1958 if (version < bytecode::kLazyLoading)
1959 lazyLoading = false;
1960 return success();
1961}
1962
1963//===----------------------------------------------------------------------===//
1964// Dialect Section
1965//===----------------------------------------------------------------------===//
1966
1967LogicalResult BytecodeDialect::load(const DialectReader &reader,
1968 MLIRContext *ctx) {
1969 if (dialect)
1970 return success();
1971 Dialect *loadedDialect = ctx->getOrLoadDialect(name);
1972 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
1973 return reader.emitError("dialect '")
1974 << name
1975 << "' is unknown. If this is intended, please call "
1976 "allowUnregisteredDialects() on the MLIRContext, or use "
1977 "-allow-unregistered-dialect with the MLIR tool used.";
1978 }
1979 dialect = loadedDialect;
1980
1981 // If the dialect was actually loaded, check to see if it has a bytecode
1982 // interface.
1983 if (loadedDialect)
1984 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
1985 if (!versionBuffer.empty()) {
1986 if (!interface)
1987 return reader.emitError("dialect '")
1988 << name
1989 << "' does not implement the bytecode interface, "
1990 "but found a version entry";
1991 EncodingReader encReader(versionBuffer, reader.getLoc());
1992 DialectReader versionReader = reader.withEncodingReader(encReader);
1993 loadedVersion = interface->readVersion(versionReader);
1994 if (!loadedVersion)
1995 return failure();
1996 }
1997 return success();
1998}
1999
2000LogicalResult
2001BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
2002 EncodingReader sectionReader(sectionData, fileLoc);
2003
2004 // Parse the number of dialects in the section.
2005 uint64_t numDialects;
2006 if (failed(sectionReader.parseVarInt(numDialects)))
2007 return failure();
2008 dialects.resize(numDialects);
2009
2010 const auto checkSectionAlignment = [&](unsigned alignment) {
2011 return this->checkSectionAlignment(alignment, [&](const auto &msg) {
2012 return sectionReader.emitError(msg);
2013 });
2014 };
2015
2016 // Parse each of the dialects.
2017 for (uint64_t i = 0; i < numDialects; ++i) {
2018 dialects[i] = std::make_unique<BytecodeDialect>();
2019 /// Before version kDialectVersioning, there wasn't any versioning available
2020 /// for dialects, and the entryIdx represent the string itself.
2021 if (version < bytecode::kDialectVersioning) {
2022 if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
2023 return failure();
2024 continue;
2025 }
2026
2027 // Parse ID representing dialect and version.
2028 uint64_t dialectNameIdx;
2029 bool versionAvailable;
2030 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
2031 versionAvailable)))
2032 return failure();
2033 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
2034 dialects[i]->name)))
2035 return failure();
2036 if (versionAvailable) {
2037 bytecode::Section::ID sectionID;
2038 if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
2039 dialects[i]->versionBuffer)))
2040 return failure();
2041 if (sectionID != bytecode::Section::kDialectVersions) {
2042 emitError(fileLoc, "expected dialect version section");
2043 return failure();
2044 }
2045 }
2046 dialectsMap[dialects[i]->name] = dialects[i].get();
2047 }
2048
2049 // Parse the operation names, which are grouped by dialect.
2050 auto parseOpName = [&](BytecodeDialect *dialect) {
2051 StringRef opName;
2052 std::optional<bool> wasRegistered;
2053 // Prior to version kNativePropertiesEncoding, the information about wheter
2054 // an op was registered or not wasn't encoded.
2056 if (failed(stringReader.parseString(sectionReader, opName)))
2057 return failure();
2058 } else {
2059 bool wasRegisteredFlag;
2060 if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
2061 wasRegisteredFlag)))
2062 return failure();
2063 wasRegistered = wasRegisteredFlag;
2064 }
2065 opNames.emplace_back(dialect, opName, wasRegistered);
2066 return success();
2067 };
2068 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
2069 // where the number of ops are known.
2071 uint64_t numOps;
2072 if (failed(sectionReader.parseVarInt(numOps)))
2073 return failure();
2074 opNames.reserve(numOps);
2075 }
2076 while (!sectionReader.empty())
2077 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
2078 return failure();
2079 return success();
2080}
2081
2082FailureOr<OperationName>
2083BytecodeReader::Impl::parseOpName(EncodingReader &reader,
2084 std::optional<bool> &wasRegistered) {
2085 BytecodeOperationName *opName = nullptr;
2086 if (failed(parseEntry(reader, opNames, opName, "operation name")))
2087 return failure();
2088 wasRegistered = opName->wasRegistered;
2089 // Check to see if this operation name has already been resolved. If we
2090 // haven't, load the dialect and build the operation name.
2091 if (!opName->opName) {
2092 // If the opName is empty, this is because we use to accept names such as
2093 // `foo` without any `.` separator. We shouldn't tolerate this in textual
2094 // format anymore but for now we'll be backward compatible. This can only
2095 // happen with unregistered dialects.
2096 if (opName->name.empty()) {
2097 opName->opName.emplace(opName->dialect->name, getContext());
2098 } else {
2099 // Load the dialect and its version.
2100 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2101 dialectsMap, reader, version);
2102 if (failed(opName->dialect->load(dialectReader, getContext())))
2103 return failure();
2104 opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
2105 getContext());
2106 }
2107 }
2108 return *opName->opName;
2109}
2110
2111//===----------------------------------------------------------------------===//
2112// Resource Section
2113//===----------------------------------------------------------------------===//
2114
2115LogicalResult BytecodeReader::Impl::parseResourceSection(
2116 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
2117 std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
2118 // Ensure both sections are either present or not.
2119 if (resourceData.has_value() != resourceOffsetData.has_value()) {
2120 if (resourceOffsetData)
2121 return emitError(fileLoc, "unexpected resource offset section when "
2122 "resource section is not present");
2123 return emitError(
2124 fileLoc,
2125 "expected resource offset section when resource section is present");
2126 }
2127
2128 // If the resource sections are absent, there is nothing to do.
2129 if (!resourceData)
2130 return success();
2131
2132 // Initialize the resource reader with the resource sections.
2133 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2134 dialectsMap, reader, version);
2135 return resourceReader.initialize(fileLoc, config, dialects, stringReader,
2136 *resourceData, *resourceOffsetData,
2137 dialectReader, bufferOwnerRef);
2138}
2139
2140//===----------------------------------------------------------------------===//
2141// UseListOrder Helpers
2142//===----------------------------------------------------------------------===//
2143
2144FailureOr<BytecodeReader::Impl::UseListMapT>
2145BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
2146 uint64_t numResults) {
2147 BytecodeReader::Impl::UseListMapT map;
2148 uint64_t numValuesToRead = 1;
2149 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead)))
2150 return failure();
2151
2152 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
2153 uint64_t resultIdx = 0;
2154 if (numResults > 1 && failed(reader.parseVarInt(resultIdx)))
2155 return failure();
2156
2157 uint64_t numValues;
2158 bool indexPairEncoding;
2159 if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
2160 return failure();
2161
2162 SmallVector<unsigned, 4> useListOrders;
2163 for (size_t idx = 0; idx < numValues; idx++) {
2164 uint64_t index;
2165 if (failed(reader.parseVarInt(index)))
2166 return failure();
2167 useListOrders.push_back(index);
2168 }
2169
2170 // Store in a map the result index
2171 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
2172 std::move(useListOrders)));
2173 }
2174
2175 return map;
2176}
2177
2178/// Sorts each use according to the order specified in the use-list parsed. If
2179/// the custom use-list is not found, this means that the order needs to be
2180/// consistent with the reverse pre-order walk of the IR. If multiple uses lie
2181/// on the same operation, the order will follow the reverse operand number
2182/// ordering.
2183LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
2184 // Early return for trivial use-lists.
2185 if (value.use_empty() || value.hasOneUse())
2186 return success();
2187
2188 bool hasIncomingOrder =
2189 valueToUseListMap.contains(value.getAsOpaquePointer());
2190
2191 // Compute the current order of the use-list with respect to the global
2192 // ordering. Detect if the order is already sorted while doing so.
2193 bool alreadySorted = true;
2194 auto &firstUse = *value.use_begin();
2195 uint64_t prevID =
2196 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
2197 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
2198 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
2199 uint64_t currentID = bytecode::getUseID(
2200 item.value(), operationIDs.at(item.value().getOwner()));
2201 alreadySorted &= prevID > currentID;
2202 currentOrder.push_back({item.index(), currentID});
2203 prevID = currentID;
2204 }
2205
2206 // If the order is already sorted, and there wasn't a custom order to apply
2207 // from the bytecode file, we are done.
2208 if (alreadySorted && !hasIncomingOrder)
2209 return success();
2210
2211 // If not already sorted, sort the indices of the current order by descending
2212 // useIDs.
2213 if (!alreadySorted)
2214 std::sort(
2215 currentOrder.begin(), currentOrder.end(),
2216 [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
2217
2218 if (!hasIncomingOrder) {
2219 // If the bytecode file did not contain any custom use-list order, it means
2220 // that the order was descending useID. Hence, shuffle by the first index
2221 // of the `currentOrder` pair.
2222 SmallVector<unsigned> shuffle(llvm::make_first_range(currentOrder));
2223 value.shuffleUseList(shuffle);
2224 return success();
2225 }
2226
2227 // Pull the custom order info from the map.
2228 UseListOrderStorage customOrder =
2229 valueToUseListMap.at(value.getAsOpaquePointer());
2230 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
2231 uint64_t numUses = value.getNumUses();
2232
2233 // If the encoding was a pair of indices `(src, dst)` for every permutation,
2234 // reconstruct the shuffle vector for every use. Initialize the shuffle vector
2235 // as identity, and then apply the mapping encoded in the indices.
2236 if (customOrder.isIndexPairEncoding) {
2237 // Return failure if the number of indices was not representing pairs.
2238 if (shuffle.size() & 1)
2239 return failure();
2240
2241 SmallVector<unsigned, 4> newShuffle(numUses);
2242 size_t idx = 0;
2243 std::iota(newShuffle.begin(), newShuffle.end(), idx);
2244 for (idx = 0; idx < shuffle.size(); idx += 2)
2245 newShuffle[shuffle[idx]] = shuffle[idx + 1];
2246
2247 shuffle = std::move(newShuffle);
2248 }
2249
2250 // Make sure that the indices represent a valid mapping. That is, the sum of
2251 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
2252 // duplicates are allowed in the list.
2254 uint64_t accumulator = 0;
2255 for (const auto &elem : shuffle) {
2256 if (!set.insert(elem).second)
2257 return failure();
2258 accumulator += elem;
2259 }
2260 if (numUses != shuffle.size() ||
2261 accumulator != (((numUses - 1) * numUses) >> 1))
2262 return failure();
2263
2264 // Apply the current ordering map onto the shuffle vector to get the final
2265 // use-list sorting indices before shuffling.
2266 shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2267 currentOrder, [&](auto item) { return shuffle[item.first]; }));
2268 value.shuffleUseList(shuffle);
2269 return success();
2270}
2271
2272LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2273 // Precompute operation IDs according to the pre-order walk of the IR. We
2274 // can't do this while parsing since parseRegions ordering is not strictly
2275 // equal to the pre-order walk.
2276 unsigned operationID = 0;
2277 topLevelOp->walk<mlir::WalkOrder::PreOrder>(
2278 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
2279
2280 auto blockWalk = topLevelOp->walk([this](Block *block) {
2281 for (auto arg : block->getArguments())
2282 if (failed(sortUseListOrder(arg)))
2283 return WalkResult::interrupt();
2284 return WalkResult::advance();
2285 });
2286
2287 auto resultWalk = topLevelOp->walk([this](Operation *op) {
2288 for (auto result : op->getResults())
2289 if (failed(sortUseListOrder(result)))
2290 return WalkResult::interrupt();
2291 return WalkResult::advance();
2292 });
2293
2294 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2295}
2296
2297//===----------------------------------------------------------------------===//
2298// IR Section
2299//===----------------------------------------------------------------------===//
2300
2301LogicalResult
2302BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2303 Block *block) {
2304 EncodingReader reader(sectionData, fileLoc);
2305
2306 // A stack of operation regions currently being read from the bytecode.
2307 std::vector<RegionReadState> regionStack;
2308
2309 // Parse the top-level block using a temporary module operation.
2310 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2311 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
2312 regionStack.back().curBlocks.push_back(moduleOp->getBody());
2313 regionStack.back().curBlock = regionStack.back().curRegion->begin();
2314 if (failed(parseBlockHeader(reader, regionStack.back())))
2315 return failure();
2316 valueScopes.emplace_back();
2317 valueScopes.back().push(regionStack.back());
2318
2319 // Iteratively parse regions until everything has been resolved.
2320 while (!regionStack.empty())
2321 if (failed(parseRegions(regionStack, regionStack.back())))
2322 return failure();
2323 if (!forwardRefOps.empty()) {
2324 return reader.emitError(
2325 "not all forward unresolved forward operand references");
2326 }
2327
2328 // Sort use-lists according to what specified in bytecode.
2329 if (failed(processUseLists(*moduleOp)))
2330 return reader.emitError(
2331 "parsed use-list orders were invalid and could not be applied");
2332
2333 // Resolve dialect version.
2334 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2335 // Parsing is complete, give an opportunity to each dialect to visit the
2336 // IR and perform upgrades.
2337 if (!byteCodeDialect->loadedVersion)
2338 continue;
2339 if (byteCodeDialect->interface &&
2340 failed(byteCodeDialect->interface->upgradeFromVersion(
2341 *moduleOp, *byteCodeDialect->loadedVersion)))
2342 return failure();
2343 }
2344
2345 // Verify that the parsed operations are valid.
2346 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
2347 return failure();
2348
2349 // Splice the parsed operations over to the provided top-level block.
2350 auto &parsedOps = moduleOp->getBody()->getOperations();
2351 auto &destOps = block->getOperations();
2352 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
2353 return success();
2354}
2355
/// Resume parsing the regions described by `readState`. Callers such as
/// parseIRSection pass `regionStack.back()`; that entry is the one popped when
/// all of its regions are done.
///
/// This returns success in one of two situations:
///   - All remaining regions of `readState` were parsed. The state is then
///     popped from `regionStack`, along with its value scope if the regions
///     are isolated from above.
///   - An operation with regions was encountered and its child state was
///     pushed onto `regionStack`. The caller's loop then processes the child
///     first, and `readState` resumes where it stopped on a later call.
///
/// When lazy loading applies to an isolated-from-above child, the child state
/// is stored in `lazyLoadableOps` instead of being pushed, and parsing of the
/// current block continues.
LogicalResult
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
                                   RegionReadState &readState) {
  // Alignment checker for nested sections; diagnostics are reported at the
  // file location.
  const auto checkSectionAlignment = [&](unsigned alignment) {
    return this->checkSectionAlignment(
        alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
  };

  // Process regions, blocks, and operations until the end or if a nested
  // region is encountered. In this case we push a new state in regionStack and
  // return, the processing of the current region will resume afterward.
  for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
    // If the current block hasn't been setup yet, parse the header for this
    // region. The current block is already setup when this function was
    // interrupted to recurse down in a nested region and we resume the current
    // block after processing the nested region.
    if (readState.curBlock == Region::iterator()) {
      if (failed(parseRegion(readState)))
        return failure();

      // If the region is empty, there is nothing more to do.
      if (readState.curRegion->empty())
        continue;
    }

    // Parse the blocks within the region.
    EncodingReader &reader = *readState.reader;
    do {
      // Post-decrement: the counter is decremented for the op about to be
      // read. When we return early to process a nested region, resuming
      // therefore continues with the following op. When the loop exits, the
      // counter has wrapped around; it is presumably reset by
      // parseBlockHeader/parseRegion before it is read again (TODO confirm).
      while (readState.numOpsRemaining--) {
        // Read in the next operation. We don't read its regions directly, we
        // handle those afterwards as necessary.
        bool isIsolatedFromAbove = false;
        FailureOr<Operation *> op =
            parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
        if (failed(op))
          return failure();

        // If the op has regions, add it to the stack for processing and return:
        // we stop the processing of the current region and resume it after the
        // inner one is completed. Unless LazyLoading is activated in which case
        // nested region parsing is delayed.
        if ((*op)->getNumRegions()) {
          RegionReadState childState(*op, &reader, isIsolatedFromAbove);

          // Isolated regions are encoded as a section in version 2 and above.
          if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
            bytecode::Section::ID sectionID;
            ArrayRef<uint8_t> sectionData;
            if (failed(reader.parseSection(sectionID, checkSectionAlignment,
                                           sectionData)))
              return failure();
            if (sectionID != bytecode::Section::kIR)
              return emitError(fileLoc, "expected IR section for region");
            // The child reads from its own section, not from the parent's
            // reader.
            childState.owningReader =
                std::make_unique<EncodingReader>(sectionData, fileLoc);
            childState.reader = childState.owningReader.get();

            // If the user has a callback set, they have the opportunity to
            // control lazyloading as we go.
            if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
              // Defer the regions: remember the child state and keep parsing
              // the current block. No value scope is pushed in this case.
              lazyLoadableOps.emplace_back(*op, std::move(childState));
              lazyLoadableOpsMap.try_emplace(*op,
                                             std::prev(lazyLoadableOps.end()));
              continue;
            }
          }
          regionStack.push_back(std::move(childState));

          // If the op is isolated from above, push a new value scope.
          if (isIsolatedFromAbove)
            valueScopes.emplace_back();
          return success();
        }
      }

      // Move to the next block of the region.
      if (++readState.curBlock == readState.curRegion->end())
        break;
      if (failed(parseBlockHeader(reader, readState)))
        return failure();
    } while (true);

    // Reset the current block and any values reserved for this region.
    readState.curBlock = {};
    valueScopes.back().pop(readState);
  }

  // When the regions have been fully parsed, pop them off of the read stack. If
  // the regions were isolated from above, we also pop the last value scope.
  if (readState.isIsolatedFromAbove) {
    assert(!valueScopes.empty() && "Expect a valueScope after reading region");
    valueScopes.pop_back();
  }
  assert(!regionStack.empty() && "Expect a regionStack after reading region");
  regionStack.pop_back();
  return success();
}
2453
2454FailureOr<Operation *>
2455BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2456 RegionReadState &readState,
2457 bool &isIsolatedFromAbove) {
2458 // Parse the name of the operation.
2459 std::optional<bool> wasRegistered;
2460 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2461 if (failed(opName))
2462 return failure();
2463
2464 // Parse the operation mask, which indicates which components of the operation
2465 // are present.
2466 uint8_t opMask;
2467 if (failed(reader.parseByte(opMask)))
2468 return failure();
2469
2470 /// Parse the location.
2471 LocationAttr opLoc;
2472 if (failed(parseAttribute(reader, opLoc)))
2473 return failure();
2474
2475 // With the location and name resolved, we can start building the operation
2476 // state.
2477 OperationState opState(opLoc, *opName);
2478
2479 // Parse the attributes of the operation.
2481 DictionaryAttr dictAttr;
2482 if (failed(parseAttribute(reader, dictAttr)))
2483 return failure();
2484 opState.attributes = dictAttr;
2485 }
2486
2488 // kHasProperties wasn't emitted in older bytecode, we should never get
2489 // there without also having the `wasRegistered` flag available.
2490 if (!wasRegistered)
2491 return emitError(fileLoc,
2492 "Unexpected missing `wasRegistered` opname flag at "
2493 "bytecode version ")
2494 << version << " with properties.";
2495 // When an operation is emitted without being registered, the properties are
2496 // stored as an attribute. Otherwise the op must implement the bytecode
2497 // interface and control the serialization.
2498 if (wasRegistered) {
2499 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2500 dialectsMap, reader, version);
2501 if (failed(
2502 propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
2503 return failure();
2504 } else {
2505 // If the operation wasn't registered when it was emitted, the properties
2506 // was serialized as an attribute.
2507 if (failed(parseAttribute(reader, opState.propertiesAttr)))
2508 return failure();
2509 }
2510 }
2511
2512 /// Parse the results of the operation.
2514 uint64_t numResults;
2515 if (failed(reader.parseVarInt(numResults)))
2516 return failure();
2517 opState.types.resize(numResults);
2518 for (int i = 0, e = numResults; i < e; ++i)
2519 if (failed(parseType(reader, opState.types[i])))
2520 return failure();
2521 }
2522
2523 /// Parse the operands of the operation.
2525 uint64_t numOperands;
2526 if (failed(reader.parseVarInt(numOperands)))
2527 return failure();
2528 opState.operands.resize(numOperands);
2529 for (int i = 0, e = numOperands; i < e; ++i)
2530 if (!(opState.operands[i] = parseOperand(reader)))
2531 return failure();
2532 }
2533
2534 /// Parse the successors of the operation.
2536 uint64_t numSuccs;
2537 if (failed(reader.parseVarInt(numSuccs)))
2538 return failure();
2539 opState.successors.resize(numSuccs);
2540 for (int i = 0, e = numSuccs; i < e; ++i) {
2541 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
2542 "successor")))
2543 return failure();
2544 }
2545 }
2546
2547 /// Parse the use-list orders for the results of the operation. Use-list
2548 /// orders are available since version 3 of the bytecode.
2549 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2550 if (version >= bytecode::kUseListOrdering &&
2552 size_t numResults = opState.types.size();
2553 auto parseResult = parseUseListOrderForRange(reader, numResults);
2554 if (failed(parseResult))
2555 return failure();
2556 resultIdxToUseListMap = std::move(*parseResult);
2557 }
2558
2559 /// Parse the regions of the operation.
2561 uint64_t numRegions;
2562 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
2563 return failure();
2564
2565 opState.regions.reserve(numRegions);
2566 for (int i = 0, e = numRegions; i < e; ++i)
2567 opState.regions.push_back(std::make_unique<Region>());
2568 }
2569
2570 // Create the operation at the back of the current block.
2571 Operation *op = Operation::create(opState);
2572 readState.curBlock->push_back(op);
2573
2574 // If the operation had results, update the value references. We don't need to
2575 // do this if the current value scope is empty. That is, the op was not
2576 // encoded within a parent region.
2577 if (readState.numValues && op->getNumResults() &&
2578 failed(defineValues(reader, op->getResults())))
2579 return failure();
2580
2581 /// Store a map for every value that received a custom use-list order from the
2582 /// bytecode file.
2583 if (resultIdxToUseListMap.has_value()) {
2584 for (size_t idx = 0; idx < op->getNumResults(); idx++) {
2585 if (resultIdxToUseListMap->contains(idx)) {
2586 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
2587 resultIdxToUseListMap->at(idx));
2588 }
2589 }
2590 }
2591 return op;
2592}
2593
2594LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2595 EncodingReader &reader = *readState.reader;
2596
2597 // Parse the number of blocks in the region.
2598 uint64_t numBlocks;
2599 if (failed(reader.parseVarInt(numBlocks)))
2600 return failure();
2601
2602 // If the region is empty, there is nothing else to do.
2603 if (numBlocks == 0)
2604 return success();
2605
2606 // Parse the number of values defined in this region.
2607 uint64_t numValues;
2608 if (failed(reader.parseVarInt(numValues)))
2609 return failure();
2610 readState.numValues = numValues;
2611
2612 // Create the blocks within this region. We do this before processing so that
2613 // we can rely on the blocks existing when creating operations.
2614 readState.curBlocks.clear();
2615 readState.curBlocks.reserve(numBlocks);
2616 for (uint64_t i = 0; i < numBlocks; ++i) {
2617 readState.curBlocks.push_back(new Block());
2618 readState.curRegion->push_back(readState.curBlocks.back());
2619 }
2620
2621 // Prepare the current value scope for this region.
2622 valueScopes.back().push(readState);
2623
2624 // Parse the entry block of the region.
2625 readState.curBlock = readState.curRegion->begin();
2626 return parseBlockHeader(reader, readState);
2627}
2628
2629LogicalResult
2630BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2631 RegionReadState &readState) {
2632 bool hasArgs;
2633 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2634 return failure();
2635
2636 // Parse the arguments of the block.
2637 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
2638 return failure();
2639
2640 // Uselist orders are available since version 3 of the bytecode.
2641 if (version < bytecode::kUseListOrdering)
2642 return success();
2643
2644 uint8_t hasUseListOrders = 0;
2645 if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
2646 return failure();
2647
2648 if (!hasUseListOrders)
2649 return success();
2650
2651 Block &blk = *readState.curBlock;
2652 auto argIdxToUseListMap =
2653 parseUseListOrderForRange(reader, blk.getNumArguments());
2654 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
2655 return failure();
2656
2657 for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
2658 if (argIdxToUseListMap->contains(idx))
2659 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
2660 argIdxToUseListMap->at(idx));
2661
2662 // We don't parse the operations of the block here, that's done elsewhere.
2663 return success();
2664}
2665
2666LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2667 Block *block) {
2668 // Parse the value ID for the first argument, and the number of arguments.
2669 uint64_t numArgs;
2670 if (failed(reader.parseVarInt(numArgs)))
2671 return failure();
2672
2673 SmallVector<Type> argTypes;
2674 SmallVector<Location> argLocs;
2675 argTypes.reserve(numArgs);
2676 argLocs.reserve(numArgs);
2677
2678 Location unknownLoc = UnknownLoc::get(config.getContext());
2679 while (numArgs--) {
2680 Type argType;
2681 LocationAttr argLoc = unknownLoc;
2683 // Parse the type with hasLoc flag to determine if it has type.
2684 uint64_t typeIdx;
2685 bool hasLoc;
2686 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2687 !(argType = attrTypeReader.resolveType(typeIdx)))
2688 return failure();
2689 if (hasLoc && failed(parseAttribute(reader, argLoc)))
2690 return failure();
2691 } else {
2692 // All args has type and location.
2693 if (failed(parseType(reader, argType)) ||
2694 failed(parseAttribute(reader, argLoc)))
2695 return failure();
2696 }
2697 argTypes.push_back(argType);
2698 argLocs.push_back(argLoc);
2699 }
2700 block->addArguments(argTypes, argLocs);
2701 return defineValues(reader, block->getArguments());
2702}
2703
2704//===----------------------------------------------------------------------===//
2705// Value Processing
2706//===----------------------------------------------------------------------===//
2707
2708Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2709 std::vector<Value> &values = valueScopes.back().values;
2710 Value *value = nullptr;
2711 if (failed(parseEntry(reader, values, value, "value")))
2712 return Value();
2713
2714 // Create a new forward reference if necessary.
2715 if (!*value)
2716 *value = createForwardRef();
2717 return *value;
2718}
2719
2720LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2721 ValueRange newValues) {
2722 ValueScope &valueScope = valueScopes.back();
2723 std::vector<Value> &values = valueScope.values;
2724
2725 unsigned &valueID = valueScope.nextValueIDs.back();
2726 unsigned valueIDEnd = valueID + newValues.size();
2727 if (valueIDEnd > values.size()) {
2728 return reader.emitError(
2729 "value index range was outside of the expected range for "
2730 "the parent region, got [",
2731 valueID, ", ", valueIDEnd, "), but the maximum index was ",
2732 values.size() - 1);
2733 }
2734
2735 // Assign the values and update any forward references.
2736 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2737 Value newValue = newValues[i];
2738
2739 // Check to see if a definition for this value already exists.
2740 if (Value oldValue = std::exchange(values[valueID], newValue)) {
2741 Operation *forwardRefOp = oldValue.getDefiningOp();
2742
2743 // Assert that this is a forward reference operation. Given how we compute
2744 // definition ids (incrementally as we parse), it shouldn't be possible
2745 // for the value to be defined any other way.
2746 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
2747 "value index was already defined?");
2748
2749 oldValue.replaceAllUsesWith(newValue);
2750 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
2751 }
2752 }
2753 return success();
2754}
2755
2756Value BytecodeReader::Impl::createForwardRef() {
2757 // Check for an available existing operation to use. Otherwise, create a new
2758 // fake operation to use for the reference.
2759 if (!openForwardRefOps.empty()) {
2760 Operation *op = &openForwardRefOps.back();
2761 op->moveBefore(&forwardRefOps, forwardRefOps.end());
2762 } else {
2763 forwardRefOps.push_back(Operation::create(forwardRefOpState));
2764 }
2765 return forwardRefOps.back().getResult(0);
2766}
2767
2768//===----------------------------------------------------------------------===//
2769// Entry Points
2770//===----------------------------------------------------------------------===//
2771
2773
2775 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
2776 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2777 Location sourceFileLoc =
2778 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2779 /*line=*/0, /*column=*/0);
2780 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
2781 bufferOwnerRef);
2782}
2783
2785 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2786 return impl->read(block, lazyOpsCallback);
2787}
2788
2790 return impl->getNumOpsToMaterialize();
2791}
2792
2794 return impl->isMaterializable(op);
2795}
2796
2798 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2799 return impl->materialize(op, lazyOpsCallback);
2800}
2801
2802LogicalResult
2804 return impl->finalize(shouldMaterialize);
2805}
2806
2807bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
2808 return buffer.getBuffer().starts_with("ML\xefR");
2809}
2810
2811/// Read the bytecode from the provided memory buffer reference.
2812/// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2813/// and may be used to extend the lifetime of the buffer.
2814static LogicalResult
2815readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
2816 const ParserConfig &config,
2817 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2818 Location sourceFileLoc =
2819 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2820 /*line=*/0, /*column=*/0);
2821 if (!isBytecode(buffer)) {
2822 return emitError(sourceFileLoc,
2823 "input buffer is not an MLIR bytecode file");
2824 }
2825
2826 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
2827 buffer, bufferOwnerRef);
2828 return reader.read(block, /*lazyOpsCallback=*/nullptr);
2829}
2830
2831LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
2832 const ParserConfig &config) {
2833 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
2834}
2835LogicalResult
2836mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
2837 Block *block, const ParserConfig &config) {
2838 return readBytecodeFileImpl(
2839 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
2840 sourceMgr);
2841}
return success()
static LogicalResult parseDialectGrouping(EncodingReader &reader, MutableArrayRef< std::unique_ptr< BytecodeDialect > > dialects, function_ref< LogicalResult(BytecodeDialect *)> entryCallback)
Parse a single dialect group encoded in the byte stream.
static LogicalResult readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
Read the bytecode from the provided memory buffer reference.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version)
Returns true if the given top-level section ID is optional.
static LogicalResult parseResourceGroup(Location fileLoc, bool allowEmpty, EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef, function_ref< StringRef(StringRef)> remapKey={}, function_ref< LogicalResult(StringRef)> processKeyFn={})
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, uint64_t index, T &entry, StringRef entryStr)
Resolve an index into the given entry list.
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, T &entry, StringRef entryStr)
Parse and resolve an index into the given entry list.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b getContext())
auto load
static std::string diag(const llvm::Value &value)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1144
MutableArrayRef< char > getMutableData()
Return a mutable reference to the raw underlying data of this blob.
Definition AsmState.h:157
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
bool isMutable() const
Return if the data of this blob is mutable.
Definition AsmState.h:164
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
OpListType & getOperations()
Definition Block.h:137
BlockArgListType getArguments()
Definition Block.h:87
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Type > > > getTypeCallbacks() const
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Attribute > > > getAttributeCallbacks() const
Returns the callbacks available to the parser.
This class is used to read a bytecode buffer and translate it into MLIR.
LogicalResult materializeAll()
Materialize all operations.
LogicalResult read(Block *block, llvm::function_ref< bool(Operation *)> lazyOps)
Read the bytecode defined within buffer into the given block.
bool isMaterializable(Operation *op)
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, llvm::MemoryBufferRef buffer, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize)
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback)
Materialize the provided operation, invoke the lazyOpsCallback on every newly found lazy operation.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback=[](Operation *) { return false;})
Materialize the provide operation.
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize=[](Operation *) { return true;})
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
BytecodeReader(llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef={})
Create a bytecode reader for the given buffer.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
bool isMaterializable(Operation *op)
Return true if the provided op is materializable.
LogicalResult readTopLevel(Block *block, llvm::function_ref< bool(Operation *)> lazyOps=[](Operation *) { return false;})
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
bool isRegistered() const
Return if this operation is registered.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents a configuration for the MLIR assembly parser.
Definition AsmState.h:469
BytecodeReaderConfig & getBytecodeReaderConfig() const
Returns the parsing configurations associated to the bytecode read.
Definition AsmState.h:489
BlockListType::iterator iterator
Definition Region.h:52
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
static AsmResourceBlob allocateWithAlign(ArrayRef< char > data, size_t align, AsmResourceBlob::DeleterFn deleter={}, bool dataIsMutable=false)
Create a new unmanaged resource directly referencing the provided data.
Definition AsmState.h:228
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
void shuffleUseList(ArrayRef< unsigned > indices)
Shuffle the use list order according to the provided indices.
Definition Value.cpp:106
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
use_iterator use_begin() const
Definition Value.h:184
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
@ kAttrType
This section contains the attributes and types referenced within an IR module.
Definition Encoding.h:73
@ kAttrTypeOffset
This section contains the offsets for the attribute and types within the AttrType section.
Definition Encoding.h:77
@ kIR
This section contains the list of operations serialized into the bytecode, and their nested regions/o...
Definition Encoding.h:81
@ kResource
This section contains the resources of the bytecode.
Definition Encoding.h:84
@ kResourceOffset
This section contains the offsets of resources within the Resource section.
Definition Encoding.h:88
@ kDialect
This section contains the dialects referenced within an IR module.
Definition Encoding.h:69
@ kString
This section contains strings referenced within the bytecode.
Definition Encoding.h:66
@ kDialectVersions
This section contains the versions of each dialect.
Definition Encoding.h:91
@ kProperties
This section contains the properties for the operations.
Definition Encoding.h:94
@ kNumSections
The total number of section types.
Definition Encoding.h:97
static uint64_t getUseID(OperandT &val, unsigned ownerID)
Get the unique ID of a value use.
Definition Encoding.h:127
@ kUseListOrdering
Use-list ordering started to be encoded in version 3.
Definition Encoding.h:38
@ kAlignmentByte
An arbitrary value used to fill alignment padding.
Definition Encoding.h:56
@ kVersion
The current bytecode version.
Definition Encoding.h:53
@ kLazyLoading
Support for lazy-loading of isolated region was added in version 2.
Definition Encoding.h:35
@ kDialectVersioning
Dialects versioning was added in version 1.
Definition Encoding.h:32
@ kElideUnknownBlockArgLocation
Avoid recording unknown locations on block arguments (compression) started in version 4.
Definition Encoding.h:42
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attribute...
Definition Encoding.h:46
@ kMinSupportedVersion
The minimum supported version of the bytecode.
Definition Encoding.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
StringRef toString(AsmResourceEntryKind kind)
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
bool isBytecode(llvm::MemoryBufferRef buffer)
Returns true if the given buffer starts with the magic bytes that signal MLIR bytecode.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
AsmResourceEntryKind
This enum represents the different kinds of resource values.
Definition AsmState.h:280
LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config)
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.