MLIR 22.0.0git
BytecodeReader.cpp
Go to the documentation of this file.
1//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/Diagnostics.h"
17#include "mlir/IR/Verifier.h"
18#include "mlir/IR/Visitors.h"
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/Endian.h"
25#include "llvm/Support/MemoryBufferRef.h"
26#include "llvm/Support/SourceMgr.h"
27
28#include <cstddef>
29#include <cstdint>
30#include <list>
31#include <memory>
32#include <numeric>
33#include <optional>
34
35#define DEBUG_TYPE "mlir-bytecode-reader"
36
37using namespace mlir;
38
39/// Stringify the given section ID.
40static std::string toString(bytecode::Section::ID sectionID) {
41 switch (sectionID) {
43 return "String (0)";
45 return "Dialect (1)";
47 return "AttrType (2)";
49 return "AttrTypeOffset (3)";
51 return "IR (4)";
53 return "Resource (5)";
55 return "ResourceOffset (6)";
57 return "DialectVersions (7)";
59 return "Properties (8)";
60 default:
61 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
62 }
63}
64
65/// Returns true if the given top-level section ID is optional.
66static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
67 switch (sectionID) {
73 return false;
77 return true;
80 default:
81 llvm_unreachable("unknown section ID");
82 }
83}
84
85//===----------------------------------------------------------------------===//
86// EncodingReader
87//===----------------------------------------------------------------------===//
88
89namespace {
90class EncodingReader {
91public:
92 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
93 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
94 explicit EncodingReader(StringRef contents, Location fileLoc)
95 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
96 contents.size()},
97 fileLoc) {}
98
99 /// Returns true if the entire section has been read.
100 bool empty() const { return dataIt == buffer.end(); }
101
102 /// Returns the remaining size of the bytecode.
103 size_t size() const { return buffer.end() - dataIt; }
104
105 /// Align the current reader position to the specified alignment.
106 LogicalResult alignTo(unsigned alignment) {
107 if (!llvm::isPowerOf2_32(alignment))
108 return emitError("expected alignment to be a power-of-two");
109
110 auto isUnaligned = [&](const uint8_t *ptr) {
111 return ((uintptr_t)ptr & (alignment - 1)) != 0;
112 };
113
114 // Shift the reader position to the next alignment boundary.
115 // Note: this assumes the pointer alignment matches the alignment of the
116 // data from the start of the buffer. In other words, this code is only
117 // valid if `dataIt` is offsetting into an already aligned buffer.
118 while (isUnaligned(dataIt)) {
119 uint8_t padding;
120 if (failed(parseByte(padding)))
121 return failure();
122 if (padding != bytecode::kAlignmentByte) {
123 return emitError("expected alignment byte (0xCB), but got: '0x" +
124 llvm::utohexstr(padding) + "'");
125 }
126 }
127
128 // Ensure the data iterator is now aligned. This case is unlikely because we
129 // *just* went through the effort to align the data iterator.
130 if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
131 return emitError("expected data iterator aligned to ", alignment,
132 ", but got pointer: '0x" +
133 llvm::utohexstr((uintptr_t)dataIt) + "'");
134 }
135
136 return success();
137 }
138
139 /// Emit an error using the given arguments.
140 template <typename... Args>
141 InFlightDiagnostic emitError(Args &&...args) const {
142 return ::emitError(fileLoc).append(std::forward<Args>(args)...);
143 }
144 InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
145
146 /// Parse a single byte from the stream.
147 template <typename T>
148 LogicalResult parseByte(T &value) {
149 if (empty())
150 return emitError("attempting to parse a byte at the end of the bytecode");
151 value = static_cast<T>(*dataIt++);
152 return success();
153 }
154 /// Parse a range of bytes of 'length' into the given result.
155 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
156 if (length > size()) {
157 return emitError("attempting to parse ", length, " bytes when only ",
158 size(), " remain");
159 }
160 result = {dataIt, length};
161 dataIt += length;
162 return success();
163 }
164 /// Parse a range of bytes of 'length' into the given result, which can be
165 /// assumed to be large enough to hold `length`.
166 LogicalResult parseBytes(size_t length, uint8_t *result) {
167 if (length > size()) {
168 return emitError("attempting to parse ", length, " bytes when only ",
169 size(), " remain");
170 }
171 memcpy(result, dataIt, length);
172 dataIt += length;
173 return success();
174 }
175
176 /// Parse an aligned blob of data, where the alignment was encoded alongside
177 /// the data.
178 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
179 uint64_t &alignment) {
180 uint64_t dataSize;
181 if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
182 failed(alignTo(alignment)))
183 return failure();
184 return parseBytes(dataSize, data);
185 }
186
187 /// Parse a variable length encoded integer from the byte stream. The first
188 /// encoded byte contains a prefix in the low bits indicating the encoded
189 /// length of the value. This length prefix is a bit sequence of '0's followed
190 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
191 /// (not including the prefix byte). All remaining bits in the first byte,
192 /// along with all of the bits in additional bytes, provide the value of the
193 /// integer encoded in little-endian order.
194 LogicalResult parseVarInt(uint64_t &result) {
195 // Parse the first byte of the encoding, which contains the length prefix.
196 if (failed(parseByte(result)))
197 return failure();
198
199 // Handle the overwhelmingly common case where the value is stored in a
200 // single byte. In this case, the first bit is the `1` marker bit.
201 if (LLVM_LIKELY(result & 1)) {
202 result >>= 1;
203 return success();
204 }
205
206 // Handle the overwhelming uncommon case where the value required all 8
207 // bytes (i.e. a really really big number). In this case, the marker byte is
208 // all zeros: `00000000`.
209 if (LLVM_UNLIKELY(result == 0)) {
210 llvm::support::ulittle64_t resultLE;
211 if (failed(parseBytes(sizeof(resultLE),
212 reinterpret_cast<uint8_t *>(&resultLE))))
213 return failure();
214 result = resultLE;
215 return success();
216 }
217 return parseMultiByteVarInt(result);
218 }
219
220 /// Parse a signed variable length encoded integer from the byte stream. A
221 /// signed varint is encoded as a normal varint with zigzag encoding applied,
222 /// i.e. the low bit of the value is used to indicate the sign.
223 LogicalResult parseSignedVarInt(uint64_t &result) {
224 if (failed(parseVarInt(result)))
225 return failure();
226 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
227 result = (result >> 1) ^ (~(result & 1) + 1);
228 return success();
229 }
230
231 /// Parse a variable length encoded integer whose low bit is used to encode an
232 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
233 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
234 if (failed(parseVarInt(result)))
235 return failure();
236 flag = result & 1;
237 result >>= 1;
238 return success();
239 }
240
241 /// Skip the first `length` bytes within the reader.
242 LogicalResult skipBytes(size_t length) {
243 if (length > size()) {
244 return emitError("attempting to skip ", length, " bytes when only ",
245 size(), " remain");
246 }
247 dataIt += length;
248 return success();
249 }
250
251 /// Parse a null-terminated string into `result` (without including the NUL
252 /// terminator).
253 LogicalResult parseNullTerminatedString(StringRef &result) {
254 const char *startIt = (const char *)dataIt;
255 const char *nulIt = (const char *)memchr(startIt, 0, size());
256 if (!nulIt)
257 return emitError(
258 "malformed null-terminated string, no null character found");
259
260 result = StringRef(startIt, nulIt - startIt);
261 dataIt = (const uint8_t *)nulIt + 1;
262 return success();
263 }
264
265 /// Validate that the alignment requested in the section is valid.
266 using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
267
268 /// Parse a section header, placing the kind of section in `sectionID` and the
269 /// contents of the section in `sectionData`.
270 LogicalResult parseSection(bytecode::Section::ID &sectionID,
271 ValidateAlignmentFn alignmentValidator,
272 ArrayRef<uint8_t> &sectionData) {
273 uint8_t sectionIDAndHasAlignment;
274 uint64_t length;
275 if (failed(parseByte(sectionIDAndHasAlignment)) ||
276 failed(parseVarInt(length)))
277 return failure();
278
279 // Extract the section ID and whether the section is aligned. The high bit
280 // of the ID is the alignment flag.
281 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
282 0b01111111);
283 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
284
285 // Check that the section is actually valid before trying to process its
286 // data.
287 if (sectionID >= bytecode::Section::kNumSections)
288 return emitError("invalid section ID: ", unsigned(sectionID));
289
290 // Process the section alignment if present.
291 if (hasAlignment) {
292 // Read the requested alignment from the bytecode parser.
293 uint64_t alignment;
294 if (failed(parseVarInt(alignment)))
295 return failure();
296
297 // Check that the requested alignment must not exceed the alignment of
298 // the root buffer itself. Otherwise we cannot guarantee that pointers
299 // derived from this buffer will actually satisfy the requested alignment
300 // globally.
301 //
302 // Consider a bytecode buffer that is guaranteed to be 8k aligned, but not
303 // 16k aligned (e.g. absolute address 40960. If a section inside this
304 // buffer declares a 16k alignment requirement, two problems can arise:
305 //
306 // (a) If we "align forward" the current pointer to the next
307 // 16k boundary, the amount of padding we skip depends on the
308 // buffer's starting address. For example:
309 //
310 // buffer_start = 40960
311 // next 16k boundary = 49152
312 // bytes skipped = 49152 - 40960 = 8192
313 //
314 // This leaves behind variable padding that could be misinterpreted
315 // as part of the next section.
316 //
317 // (b) If we align relative to the buffer start, we may
318 // obtain addresses that are multiples of "buffer_start +
319 // section_alignment" rather than truly globally aligned
320 // addresses. For example:
321 //
322 // buffer_start = 40960 (5×8k, 8k aligned but not 16k)
323 // offset = 16384 (first multiple of 16k)
324 // section_ptr = 40960 + 16384 = 57344
325 //
326 // 57344 is 8k aligned but not 16k aligned.
327 // Any consumer expecting true 16k alignment would see this as a
328 // violation.
329 if (failed(alignmentValidator(alignment)))
330 return emitError("failed to align section ID: ", unsigned(sectionID));
331
332 // Align the buffer.
333 if (failed(alignTo(alignment)))
334 return failure();
335 }
336
337 // Parse the actual section data.
338 return parseBytes(static_cast<size_t>(length), sectionData);
339 }
340
341 Location getLoc() const { return fileLoc; }
342
343private:
344 /// Parse a variable length encoded integer from the byte stream. This method
345 /// is a fallback when the number of bytes used to encode the value is greater
346 /// than 1, but less than the max (9). The provided `result` value can be
347 /// assumed to already contain the first byte of the value.
348 /// NOTE: This method is marked noinline to avoid pessimizing the common case
349 /// of single byte encoding.
350 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
351 // Count the number of trailing zeros in the marker byte, this indicates the
352 // number of trailing bytes that are part of the value. We use `uint32_t`
353 // here because we only care about the first byte, and so that be actually
354 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
355 // implementation).
356 uint32_t numBytes = llvm::countr_zero<uint32_t>(result);
357 assert(numBytes > 0 && numBytes <= 7 &&
358 "unexpected number of trailing zeros in varint encoding");
359
360 // Parse in the remaining bytes of the value.
361 llvm::support::ulittle64_t resultLE(result);
362 if (failed(
363 parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1)))
364 return failure();
365
366 // Shift out the low-order bits that were used to mark how the value was
367 // encoded.
368 result = resultLE >> (numBytes + 1);
369 return success();
370 }
371
372 /// The bytecode buffer.
373 ArrayRef<uint8_t> buffer;
374
375 /// The current iterator within the 'buffer'.
376 const uint8_t *dataIt;
377
378 /// A location for the bytecode used to report errors.
379 Location fileLoc;
380};
381} // namespace
382
383/// Resolve an index into the given entry list. `entry` may either be a
384/// reference, in which case it is assigned to the corresponding value in
385/// `entries`, or a pointer, in which case it is assigned to the address of the
386/// element in `entries`.
387template <typename RangeT, typename T>
388static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
389 uint64_t index, T &entry,
390 StringRef entryStr) {
391 if (index >= entries.size())
392 return reader.emitError("invalid ", entryStr, " index: ", index);
393
394 // If the provided entry is a pointer, resolve to the address of the entry.
395 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
396 entry = entries[index];
397 else
398 entry = &entries[index];
399 return success();
400}
401
402/// Parse and resolve an index into the given entry list.
403template <typename RangeT, typename T>
404static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
405 T &entry, StringRef entryStr) {
406 uint64_t entryIdx;
407 if (failed(reader.parseVarInt(entryIdx)))
408 return failure();
409 return resolveEntry(reader, entries, entryIdx, entry, entryStr);
410}
411
412//===----------------------------------------------------------------------===//
413// StringSectionReader
414//===----------------------------------------------------------------------===//
415
416namespace {
417/// This class is used to read references to the string section from the
418/// bytecode.
419class StringSectionReader {
420public:
421 /// Initialize the string section reader with the given section data.
422 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
423
424 /// Parse a shared string from the string section. The shared string is
425 /// encoded using an index to a corresponding string in the string section.
426 LogicalResult parseString(EncodingReader &reader, StringRef &result) const {
427 return parseEntry(reader, strings, result, "string");
428 }
429
430 /// Parse a shared string from the string section. The shared string is
431 /// encoded using an index to a corresponding string in the string section.
432 /// This variant parses a flag compressed with the index.
433 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
434 bool &flag) const {
435 uint64_t entryIdx;
436 if (failed(reader.parseVarIntWithFlag(entryIdx, flag)))
437 return failure();
438 return parseStringAtIndex(reader, entryIdx, result);
439 }
440
441 /// Parse a shared string from the string section. The shared string is
442 /// encoded using an index to a corresponding string in the string section.
443 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
444 StringRef &result) const {
445 return resolveEntry(reader, strings, index, result, "string");
446 }
447
448private:
449 /// The table of strings referenced within the bytecode file.
450 SmallVector<StringRef> strings;
451};
452} // namespace
453
454LogicalResult StringSectionReader::initialize(Location fileLoc,
455 ArrayRef<uint8_t> sectionData) {
456 EncodingReader stringReader(sectionData, fileLoc);
457
458 // Parse the number of strings in the section.
459 uint64_t numStrings;
460 if (failed(stringReader.parseVarInt(numStrings)))
461 return failure();
462 strings.resize(numStrings);
463
464 // Parse each of the strings. The sizes of the strings are encoded in reverse
465 // order, so that's the order we populate the table.
466 size_t stringDataEndOffset = sectionData.size();
467 for (StringRef &string : llvm::reverse(strings)) {
468 uint64_t stringSize;
469 if (failed(stringReader.parseVarInt(stringSize)))
470 return failure();
471 if (stringDataEndOffset < stringSize) {
472 return stringReader.emitError(
473 "string size exceeds the available data size");
474 }
475
476 // Extract the string from the data, dropping the null character.
477 size_t stringOffset = stringDataEndOffset - stringSize;
478 string = StringRef(
479 reinterpret_cast<const char *>(sectionData.data() + stringOffset),
480 stringSize - 1);
481 stringDataEndOffset = stringOffset;
482 }
483
484 // Check that the only remaining data was for the strings, i.e. the reader
485 // should be at the same offset as the first string.
486 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
487 return stringReader.emitError("unexpected trailing data between the "
488 "offsets for strings and their data");
489 }
490 return success();
491}
492
493//===----------------------------------------------------------------------===//
494// BytecodeDialect
495//===----------------------------------------------------------------------===//
496
497namespace {
498class DialectReader;
499
500/// This struct represents a dialect entry within the bytecode.
501struct BytecodeDialect {
502 /// Load the dialect into the provided context if it hasn't been loaded yet.
503 /// Returns failure if the dialect couldn't be loaded *and* the provided
504 /// context does not allow unregistered dialects. The provided reader is used
505 /// for error emission if necessary.
506 LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
507
508 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
509 /// only be called after `load`.
510 Dialect *getLoadedDialect() const {
511 assert(dialect &&
512 "expected `load` to be invoked before `getLoadedDialect`");
513 return *dialect;
514 }
515
516 /// The loaded dialect entry. This field is std::nullopt if we haven't
517 /// attempted to load, nullptr if we failed to load, otherwise the loaded
518 /// dialect.
519 std::optional<Dialect *> dialect;
520
521 /// The bytecode interface of the dialect, or nullptr if the dialect does not
522 /// implement the bytecode interface. This field should only be checked if the
523 /// `dialect` field is not std::nullopt.
524 const BytecodeDialectInterface *interface = nullptr;
525
526 /// The name of the dialect.
527 StringRef name;
528
529 /// A buffer containing the encoding of the dialect version parsed.
530 ArrayRef<uint8_t> versionBuffer;
531
532 /// Lazy loaded dialect version from the handle above.
533 std::unique_ptr<DialectVersion> loadedVersion;
534};
535
536/// This struct represents an operation name entry within the bytecode.
537struct BytecodeOperationName {
538 BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
539 std::optional<bool> wasRegistered)
540 : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
541
542 /// The loaded operation name, or std::nullopt if it hasn't been processed
543 /// yet.
544 std::optional<OperationName> opName;
545
546 /// The dialect that owns this operation name.
547 BytecodeDialect *dialect;
548
549 /// The name of the operation, without the dialect prefix.
550 StringRef name;
551
552 /// Whether this operation was registered when the bytecode was produced.
553 /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
554 std::optional<bool> wasRegistered;
555};
556} // namespace
557
558/// Parse a single dialect group encoded in the byte stream.
559static LogicalResult parseDialectGrouping(
560 EncodingReader &reader,
561 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
562 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
563 // Parse the dialect and the number of entries in the group.
564 std::unique_ptr<BytecodeDialect> *dialect;
565 if (failed(parseEntry(reader, dialects, dialect, "dialect")))
566 return failure();
567 uint64_t numEntries;
568 if (failed(reader.parseVarInt(numEntries)))
569 return failure();
570
571 for (uint64_t i = 0; i < numEntries; ++i)
572 if (failed(entryCallback(dialect->get())))
573 return failure();
574 return success();
575}
576
577//===----------------------------------------------------------------------===//
578// ResourceSectionReader
579//===----------------------------------------------------------------------===//
580
581namespace {
582/// This class is used to read the resource section from the bytecode.
583class ResourceSectionReader {
584public:
585 /// Initialize the resource section reader with the given section data.
586 LogicalResult
587 initialize(Location fileLoc, const ParserConfig &config,
588 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
589 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
590 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
591 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
592
593 /// Parse a dialect resource handle from the resource section.
594 LogicalResult parseResourceHandle(EncodingReader &reader,
595 AsmDialectResourceHandle &result) const {
596 return parseEntry(reader, dialectResources, result, "resource handle");
597 }
598
599private:
600 /// The table of dialect resources within the bytecode file.
601 SmallVector<AsmDialectResourceHandle> dialectResources;
602 llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
603};
604
605class ParsedResourceEntry : public AsmParsedResourceEntry {
606public:
607 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
608 EncodingReader &reader, StringSectionReader &stringReader,
609 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
610 : key(key), kind(kind), reader(reader), stringReader(stringReader),
611 bufferOwnerRef(bufferOwnerRef) {}
612 ~ParsedResourceEntry() override = default;
613
614 StringRef getKey() const final { return key; }
615
616 InFlightDiagnostic emitError() const final { return reader.emitError(); }
617
618 AsmResourceEntryKind getKind() const final { return kind; }
619
620 FailureOr<bool> parseAsBool() const final {
621 if (kind != AsmResourceEntryKind::Bool)
622 return emitError() << "expected a bool resource entry, but found a "
623 << toString(kind) << " entry instead";
624
625 bool value;
626 if (failed(reader.parseByte(value)))
627 return failure();
628 return value;
629 }
630 FailureOr<std::string> parseAsString() const final {
631 if (kind != AsmResourceEntryKind::String)
632 return emitError() << "expected a string resource entry, but found a "
633 << toString(kind) << " entry instead";
634
635 StringRef string;
636 if (failed(stringReader.parseString(reader, string)))
637 return failure();
638 return string.str();
639 }
640
641 FailureOr<AsmResourceBlob>
642 parseAsBlob(BlobAllocatorFn allocator) const final {
643 if (kind != AsmResourceEntryKind::Blob)
644 return emitError() << "expected a blob resource entry, but found a "
645 << toString(kind) << " entry instead";
646
647 ArrayRef<uint8_t> data;
648 uint64_t alignment;
649 if (failed(reader.parseBlobAndAlignment(data, alignment)))
650 return failure();
651
652 // If we have an extendable reference to the buffer owner, we don't need to
653 // allocate a new buffer for the data, and can use the data directly.
654 if (bufferOwnerRef) {
655 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
656 data.size());
657
658 // Allocate an unmanager buffer which captures a reference to the owner.
659 // For now we just mark this as immutable, but in the future we should
660 // explore marking this as mutable when desired.
662 charData, alignment,
663 [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
664 }
665
666 // Allocate memory for the blob using the provided allocator and copy the
667 // data into it.
668 AsmResourceBlob blob = allocator(data.size(), alignment);
669 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
670 blob.isMutable() &&
671 "blob allocator did not return a properly aligned address");
672 memcpy(blob.getMutableData().data(), data.data(), data.size());
673 return blob;
674 }
675
676private:
677 StringRef key;
679 EncodingReader &reader;
680 StringSectionReader &stringReader;
681 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
682};
683} // namespace
684
685template <typename T>
686static LogicalResult
687parseResourceGroup(Location fileLoc, bool allowEmpty,
688 EncodingReader &offsetReader, EncodingReader &resourceReader,
689 StringSectionReader &stringReader, T *handler,
690 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
691 function_ref<StringRef(StringRef)> remapKey = {},
692 function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
693 uint64_t numResources;
694 if (failed(offsetReader.parseVarInt(numResources)))
695 return failure();
696
697 for (uint64_t i = 0; i < numResources; ++i) {
698 StringRef key;
700 uint64_t resourceOffset;
701 ArrayRef<uint8_t> data;
702 if (failed(stringReader.parseString(offsetReader, key)) ||
703 failed(offsetReader.parseVarInt(resourceOffset)) ||
704 failed(offsetReader.parseByte(kind)) ||
705 failed(resourceReader.parseBytes(resourceOffset, data)))
706 return failure();
707
708 // Process the resource key.
709 if ((processKeyFn && failed(processKeyFn(key))))
710 return failure();
711
712 // If the resource data is empty and we allow it, don't error out when
713 // parsing below, just skip it.
714 if (allowEmpty && data.empty())
715 continue;
716
717 // Ignore the entry if we don't have a valid handler.
718 if (!handler)
719 continue;
720
721 // Otherwise, parse the resource value.
722 EncodingReader entryReader(data, fileLoc);
723 key = remapKey(key);
724 ParsedResourceEntry entry(key, kind, entryReader, stringReader,
725 bufferOwnerRef);
726 if (failed(handler->parseResource(entry)))
727 return failure();
728 if (!entryReader.empty()) {
729 return entryReader.emitError(
730 "unexpected trailing bytes in resource entry '", key, "'");
731 }
732 }
733 return success();
734}
735
736LogicalResult ResourceSectionReader::initialize(
737 Location fileLoc, const ParserConfig &config,
738 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
739 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
740 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
741 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
742 EncodingReader resourceReader(sectionData, fileLoc);
743 EncodingReader offsetReader(offsetSectionData, fileLoc);
744
745 // Read the number of external resource providers.
746 uint64_t numExternalResourceGroups;
747 if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
748 return failure();
749
750 // Utility functor that dispatches to `parseResourceGroup`, but implicitly
751 // provides most of the arguments.
752 auto parseGroup = [&](auto *handler, bool allowEmpty = false,
753 function_ref<LogicalResult(StringRef)> keyFn = {}) {
754 auto resolveKey = [&](StringRef key) -> StringRef {
755 auto it = dialectResourceHandleRenamingMap.find(key);
756 if (it == dialectResourceHandleRenamingMap.end())
757 return key;
758 return it->second;
759 };
760
761 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
762 stringReader, handler, bufferOwnerRef, resolveKey,
763 keyFn);
764 };
765
766 // Read the external resources from the bytecode.
767 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
768 StringRef key;
769 if (failed(stringReader.parseString(offsetReader, key)))
770 return failure();
771
772 // Get the handler for these resources.
773 // TODO: Should we require handling external resources in some scenarios?
774 AsmResourceParser *handler = config.getResourceParser(key);
775 if (!handler) {
776 emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
777 << "'";
778 }
779
780 if (failed(parseGroup(handler)))
781 return failure();
782 }
783
784 // Read the dialect resources from the bytecode.
785 MLIRContext *ctx = fileLoc->getContext();
786 while (!offsetReader.empty()) {
787 std::unique_ptr<BytecodeDialect> *dialect;
788 if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
789 failed((*dialect)->load(dialectReader, ctx)))
790 return failure();
791 Dialect *loadedDialect = (*dialect)->getLoadedDialect();
792 if (!loadedDialect) {
793 return resourceReader.emitError()
794 << "dialect '" << (*dialect)->name << "' is unknown";
795 }
796 const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
797 if (!handler) {
798 return resourceReader.emitError()
799 << "unexpected resources for dialect '" << (*dialect)->name << "'";
800 }
801
802 // Ensure that each resource is declared before being processed.
803 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
804 FailureOr<AsmDialectResourceHandle> handle =
805 handler->declareResource(key);
806 if (failed(handle)) {
807 return resourceReader.emitError()
808 << "unknown 'resource' key '" << key << "' for dialect '"
809 << (*dialect)->name << "'";
810 }
811 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
812 dialectResources.push_back(*handle);
813 return success();
814 };
815
816 // Parse the resources for this dialect. We allow empty resources because we
817 // just treat these as declarations.
818 if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
819 return failure();
820 }
821
822 return success();
823}
824
825//===----------------------------------------------------------------------===//
826// Attribute/Type Reader
827//===----------------------------------------------------------------------===//
828
829namespace {
830/// This class provides support for reading attribute and type entries from the
831/// bytecode. Attribute and Type entries are read lazily on demand, so we use
832/// this reader to manage when to actually parse them from the bytecode.
833class AttrTypeReader {
834 /// This class represents a single attribute or type entry.
835 template <typename T>
836 struct Entry {
837 /// The entry, or null if it hasn't been resolved yet.
838 T entry = {};
839 /// The parent dialect of this entry.
840 BytecodeDialect *dialect = nullptr;
841 /// A flag indicating if the entry was encoded using a custom encoding,
842 /// instead of using the textual assembly format.
843 bool hasCustomEncoding = false;
844 /// The raw data of this entry in the bytecode.
845 ArrayRef<uint8_t> data;
846 };
847 using AttrEntry = Entry<Attribute>;
848 using TypeEntry = Entry<Type>;
849
850public:
851 AttrTypeReader(const StringSectionReader &stringReader,
852 const ResourceSectionReader &resourceReader,
853 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
854 uint64_t &bytecodeVersion, Location fileLoc,
855 const ParserConfig &config)
856 : stringReader(stringReader), resourceReader(resourceReader),
857 dialectsMap(dialectsMap), fileLoc(fileLoc),
858 bytecodeVersion(bytecodeVersion), parserConfig(config) {}
859
860 /// Initialize the attribute and type information within the reader.
861 LogicalResult
862 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
863 ArrayRef<uint8_t> sectionData,
864 ArrayRef<uint8_t> offsetSectionData);
865
866 /// Resolve the attribute or type at the given index. Returns nullptr on
867 /// failure.
868 Attribute resolveAttribute(size_t index) {
869 return resolveEntry(attributes, index, "Attribute");
870 }
871 Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
872
873 /// Parse a reference to an attribute or type using the given reader.
874 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
875 uint64_t attrIdx;
876 if (failed(reader.parseVarInt(attrIdx)))
877 return failure();
878 result = resolveAttribute(attrIdx);
879 return success(!!result);
880 }
881 LogicalResult parseOptionalAttribute(EncodingReader &reader,
882 Attribute &result) {
883 uint64_t attrIdx;
884 bool flag;
885 if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
886 return failure();
887 if (!flag)
888 return success();
889 result = resolveAttribute(attrIdx);
890 return success(!!result);
891 }
892
893 LogicalResult parseType(EncodingReader &reader, Type &result) {
894 uint64_t typeIdx;
895 if (failed(reader.parseVarInt(typeIdx)))
896 return failure();
897 result = resolveType(typeIdx);
898 return success(!!result);
899 }
900
901 template <typename T>
902 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
903 Attribute baseResult;
904 if (failed(parseAttribute(reader, baseResult)))
905 return failure();
906 if ((result = dyn_cast<T>(baseResult)))
907 return success();
908 return reader.emitError("expected attribute of type: ",
909 llvm::getTypeName<T>(), ", but got: ", baseResult);
910 }
911
912private:
913 /// Resolve the given entry at `index`.
914 template <typename T>
915 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
916 StringRef entryType);
917
918 /// Parse an entry using the given reader that was encoded using the textual
919 /// assembly format.
920 template <typename T>
921 LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
922 StringRef entryType);
923
924 /// Parse an entry using the given reader that was encoded using a custom
925 /// bytecode format.
926 template <typename T>
927 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
928 StringRef entryType);
929
930 /// The string section reader used to resolve string references when parsing
931 /// custom encoded attribute/type entries.
932 const StringSectionReader &stringReader;
933
934 /// The resource section reader used to resolve resource references when
935 /// parsing custom encoded attribute/type entries.
936 const ResourceSectionReader &resourceReader;
937
938 /// The map of the loaded dialects used to retrieve dialect information, such
939 /// as the dialect version.
940 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
941
942 /// The set of attribute and type entries.
943 SmallVector<AttrEntry> attributes;
944 SmallVector<TypeEntry> types;
945
946 /// A location used for error emission.
947 Location fileLoc;
948
949 /// Current bytecode version being used.
950 uint64_t &bytecodeVersion;
951
952 /// Reference to the parser configuration.
953 const ParserConfig &parserConfig;
954};
955
956class DialectReader : public DialectBytecodeReader {
957public:
958 DialectReader(AttrTypeReader &attrTypeReader,
959 const StringSectionReader &stringReader,
960 const ResourceSectionReader &resourceReader,
961 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
962 EncodingReader &reader, uint64_t &bytecodeVersion)
963 : attrTypeReader(attrTypeReader), stringReader(stringReader),
964 resourceReader(resourceReader), dialectsMap(dialectsMap),
965 reader(reader), bytecodeVersion(bytecodeVersion) {}
966
967 InFlightDiagnostic emitError(const Twine &msg) const override {
968 return reader.emitError(msg);
969 }
970
971 FailureOr<const DialectVersion *>
972 getDialectVersion(StringRef dialectName) const override {
973 // First check if the dialect is available in the map.
974 auto dialectEntry = dialectsMap.find(dialectName);
975 if (dialectEntry == dialectsMap.end())
976 return failure();
977 // If the dialect was found, try to load it. This will trigger reading the
978 // bytecode version from the version buffer if it wasn't already processed.
979 // Return failure if either of those two actions could not be completed.
980 if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
981 dialectEntry->getValue()->loadedVersion == nullptr)
982 return failure();
983 return dialectEntry->getValue()->loadedVersion.get();
984 }
985
986 MLIRContext *getContext() const override { return getLoc().getContext(); }
987
988 uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
989
990 DialectReader withEncodingReader(EncodingReader &encReader) const {
991 return DialectReader(attrTypeReader, stringReader, resourceReader,
992 dialectsMap, encReader, bytecodeVersion);
993 }
994
995 Location getLoc() const { return reader.getLoc(); }
996
997 //===--------------------------------------------------------------------===//
998 // IR
999 //===--------------------------------------------------------------------===//
1000
1001 LogicalResult readAttribute(Attribute &result) override {
1002 return attrTypeReader.parseAttribute(reader, result);
1003 }
1004 LogicalResult readOptionalAttribute(Attribute &result) override {
1005 return attrTypeReader.parseOptionalAttribute(reader, result);
1006 }
1007 LogicalResult readType(Type &result) override {
1008 return attrTypeReader.parseType(reader, result);
1009 }
1010
1011 FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
1012 AsmDialectResourceHandle handle;
1013 if (failed(resourceReader.parseResourceHandle(reader, handle)))
1014 return failure();
1015 return handle;
1016 }
1017
1018 //===--------------------------------------------------------------------===//
1019 // Primitives
1020 //===--------------------------------------------------------------------===//
1021
1022 LogicalResult readVarInt(uint64_t &result) override {
1023 return reader.parseVarInt(result);
1024 }
1025
1026 LogicalResult readSignedVarInt(int64_t &result) override {
1027 uint64_t unsignedResult;
1028 if (failed(reader.parseSignedVarInt(unsignedResult)))
1029 return failure();
1030 result = static_cast<int64_t>(unsignedResult);
1031 return success();
1032 }
1033
1034 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
1035 // Small values are encoded using a single byte.
1036 if (bitWidth <= 8) {
1037 uint8_t value;
1038 if (failed(reader.parseByte(value)))
1039 return failure();
1040 return APInt(bitWidth, value);
1041 }
1042
1043 // Large values up to 64 bits are encoded using a single varint.
1044 if (bitWidth <= 64) {
1045 uint64_t value;
1046 if (failed(reader.parseSignedVarInt(value)))
1047 return failure();
1048 return APInt(bitWidth, value);
1049 }
1050
1051 // Otherwise, for really big values we encode the array of active words in
1052 // the value.
1053 uint64_t numActiveWords;
1054 if (failed(reader.parseVarInt(numActiveWords)))
1055 return failure();
1056 SmallVector<uint64_t, 4> words(numActiveWords);
1057 for (uint64_t i = 0; i < numActiveWords; ++i)
1058 if (failed(reader.parseSignedVarInt(words[i])))
1059 return failure();
1060 return APInt(bitWidth, words);
1061 }
1062
1063 FailureOr<APFloat>
1064 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
1065 FailureOr<APInt> intVal =
1066 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics));
1067 if (failed(intVal))
1068 return failure();
1069 return APFloat(semantics, *intVal);
1070 }
1071
1072 LogicalResult readString(StringRef &result) override {
1073 return stringReader.parseString(reader, result);
1074 }
1075
1076 LogicalResult readBlob(ArrayRef<char> &result) override {
1077 uint64_t dataSize;
1078 ArrayRef<uint8_t> data;
1079 if (failed(reader.parseVarInt(dataSize)) ||
1080 failed(reader.parseBytes(dataSize, data)))
1081 return failure();
1082 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()),
1083 data.size());
1084 return success();
1085 }
1086
1087 LogicalResult readBool(bool &result) override {
1088 return reader.parseByte(result);
1089 }
1090
1091private:
1092 AttrTypeReader &attrTypeReader;
1093 const StringSectionReader &stringReader;
1094 const ResourceSectionReader &resourceReader;
1095 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1096 EncodingReader &reader;
1097 uint64_t &bytecodeVersion;
1098};
1099
1100/// Wraps the properties section and handles reading properties out of it.
1101class PropertiesSectionReader {
1102public:
1103 /// Initialize the properties section reader with the given section data.
1104 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1105 if (sectionData.empty())
1106 return success();
1107 EncodingReader propReader(sectionData, fileLoc);
1108 uint64_t count;
1109 if (failed(propReader.parseVarInt(count)))
1110 return failure();
1111 // Parse the raw properties buffer.
1112 if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
1113 return failure();
1114
1115 EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1116 offsetTable.reserve(count);
1117 for (auto idx : llvm::seq<int64_t>(0, count)) {
1118 (void)idx;
1119 offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
1120 ArrayRef<uint8_t> rawProperties;
1121 uint64_t dataSize;
1122 if (failed(offsetsReader.parseVarInt(dataSize)) ||
1123 failed(offsetsReader.parseBytes(dataSize, rawProperties)))
1124 return failure();
1125 }
1126 if (!offsetsReader.empty())
1127 return offsetsReader.emitError()
1128 << "Broken properties section: didn't exhaust the offsets table";
1129 return success();
1130 }
1131
1132 LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1133 OperationName *opName, OperationState &opState) const {
1134 uint64_t propertiesIdx;
1135 if (failed(dialectReader.readVarInt(propertiesIdx)))
1136 return failure();
1137 if (propertiesIdx >= offsetTable.size())
1138 return dialectReader.emitError("Properties idx out-of-bound for ")
1139 << opName->getStringRef();
1140 size_t propertiesOffset = offsetTable[propertiesIdx];
1141 if (propertiesIdx >= propertiesBuffers.size())
1142 return dialectReader.emitError("Properties offset out-of-bound for ")
1143 << opName->getStringRef();
1144
1145 // Acquire the sub-buffer that represent the requested properties.
1146 ArrayRef<char> rawProperties;
1147 {
1148 // "Seek" to the requested offset by getting a new reader with the right
1149 // sub-buffer.
1150 EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
1151 fileLoc);
1152 // Properties are stored as a sequence of {size + raw_data}.
1153 if (failed(
1154 dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
1155 return failure();
1156 }
1157 // Setup a new reader to read from the `rawProperties` sub-buffer.
1158 EncodingReader reader(
1159 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1160 DialectReader propReader = dialectReader.withEncodingReader(reader);
1161
1162 auto *iface = opName->getInterface<BytecodeOpInterface>();
1163 if (iface)
1164 return iface->readProperties(propReader, opState);
1165 if (opName->isRegistered())
1166 return propReader.emitError(
1167 "has properties but missing BytecodeOpInterface for ")
1168 << opName->getStringRef();
1169 // Unregistered op are storing properties as an attribute.
1170 return propReader.readAttribute(opState.propertiesAttr);
1171 }
1172
1173private:
1174 /// The properties buffer referenced within the bytecode file.
1175 ArrayRef<uint8_t> propertiesBuffers;
1176
1177 /// Table of offset in the buffer above.
1178 SmallVector<int64_t> offsetTable;
1179};
1180} // namespace
1181
1182LogicalResult AttrTypeReader::initialize(
1183 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1184 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1185 EncodingReader offsetReader(offsetSectionData, fileLoc);
1186
1187 // Parse the number of attribute and type entries.
1188 uint64_t numAttributes, numTypes;
1189 if (failed(offsetReader.parseVarInt(numAttributes)) ||
1190 failed(offsetReader.parseVarInt(numTypes)))
1191 return failure();
1192 attributes.resize(numAttributes);
1193 types.resize(numTypes);
1194
1195 // A functor used to accumulate the offsets for the entries in the given
1196 // range.
1197 uint64_t currentOffset = 0;
1198 auto parseEntries = [&](auto &&range) {
1199 size_t currentIndex = 0, endIndex = range.size();
1200
1201 // Parse an individual entry.
1202 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1203 auto &entry = range[currentIndex++];
1204
1205 uint64_t entrySize;
1206 if (failed(offsetReader.parseVarIntWithFlag(entrySize,
1207 entry.hasCustomEncoding)))
1208 return failure();
1209
1210 // Verify that the offset is actually valid.
1211 if (currentOffset + entrySize > sectionData.size()) {
1212 return offsetReader.emitError(
1213 "Attribute or Type entry offset points past the end of section");
1214 }
1215
1216 entry.data = sectionData.slice(currentOffset, entrySize);
1217 entry.dialect = dialect;
1218 currentOffset += entrySize;
1219 return success();
1220 };
1221 while (currentIndex != endIndex)
1222 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
1223 return failure();
1224 return success();
1225 };
1226
1227 // Process each of the attributes, and then the types.
1228 if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
1229 return failure();
1230
1231 // Ensure that we read everything from the section.
1232 if (!offsetReader.empty()) {
1233 return offsetReader.emitError(
1234 "unexpected trailing data in the Attribute/Type offset section");
1235 }
1236
1237 return success();
1238}
1239
1240template <typename T>
1241T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1242 StringRef entryType) {
1243 if (index >= entries.size()) {
1244 emitError(fileLoc) << "invalid " << entryType << " index: " << index;
1245 return {};
1246 }
1247
1248 // If the entry has already been resolved, there is nothing left to do.
1249 Entry<T> &entry = entries[index];
1250 if (entry.entry)
1251 return entry.entry;
1252
1253 // Parse the entry.
1254 EncodingReader reader(entry.data, fileLoc);
1255
1256 // Parse based on how the entry was encoded.
1257 if (entry.hasCustomEncoding) {
1258 if (failed(parseCustomEntry(entry, reader, entryType)))
1259 return T();
1260 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1261 return T();
1262 }
1263
1264 if (!reader.empty()) {
1265 reader.emitError("unexpected trailing bytes after " + entryType + " entry");
1266 return T();
1267 }
1268 return entry.entry;
1269}
1270
1271template <typename T>
1272LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1273 StringRef entryType) {
1274 StringRef asmStr;
1275 if (failed(reader.parseNullTerminatedString(asmStr)))
1276 return failure();
1277
1278 // Invoke the MLIR assembly parser to parse the entry text.
1279 size_t numRead = 0;
1280 MLIRContext *context = fileLoc->getContext();
1281 if constexpr (std::is_same_v<T, Type>)
1282 result =
1283 ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1284 else
1285 result = ::parseAttribute(asmStr, context, Type(), &numRead,
1286 /*isKnownNullTerminated=*/true);
1287 if (!result)
1288 return failure();
1289
1290 // Ensure there weren't dangling characters after the entry.
1291 if (numRead != asmStr.size()) {
1292 return reader.emitError("trailing characters found after ", entryType,
1293 " assembly format: ", asmStr.drop_front(numRead));
1294 }
1295 return success();
1296}
1297
1298template <typename T>
1299LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1300 EncodingReader &reader,
1301 StringRef entryType) {
1302 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1303 reader, bytecodeVersion);
1304 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
1305 return failure();
1306
1307 if constexpr (std::is_same_v<T, Type>) {
1308 // Try parsing with callbacks first if available.
1309 for (const auto &callback :
1310 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
1311 if (failed(
1312 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1313 return failure();
1314 // Early return if parsing was successful.
1315 if (!!entry.entry)
1316 return success();
1317
1318 // Reset the reader if we failed to parse, so we can fall through the
1319 // other parsing functions.
1320 reader = EncodingReader(entry.data, reader.getLoc());
1321 }
1322 } else {
1323 // Try parsing with callbacks first if available.
1324 for (const auto &callback :
1326 if (failed(
1327 callback->read(dialectReader, entry.dialect->name, entry.entry)))
1328 return failure();
1329 // Early return if parsing was successful.
1330 if (!!entry.entry)
1331 return success();
1332
1333 // Reset the reader if we failed to parse, so we can fall through the
1334 // other parsing functions.
1335 reader = EncodingReader(entry.data, reader.getLoc());
1336 }
1337 }
1338
1339 // Ensure that the dialect implements the bytecode interface.
1340 if (!entry.dialect->interface) {
1341 return reader.emitError("dialect '", entry.dialect->name,
1342 "' does not implement the bytecode interface");
1343 }
1344
1345 if constexpr (std::is_same_v<T, Type>)
1346 entry.entry = entry.dialect->interface->readType(dialectReader);
1347 else
1348 entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1349
1350 return success(!!entry.entry);
1351}
1352
1353//===----------------------------------------------------------------------===//
1354// Bytecode Reader
1355//===----------------------------------------------------------------------===//
1356
1357/// This class is used to read a bytecode buffer and translate it into MLIR.
1359 struct RegionReadState;
1360 using LazyLoadableOpsInfo =
1361 std::list<std::pair<Operation *, RegionReadState>>;
1362 using LazyLoadableOpsMap =
1364
1365public:
1366 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
1367 llvm::MemoryBufferRef buffer,
1368 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1369 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1370 attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1371 fileLoc, config),
1372 // Use the builtin unrealized conversion cast operation to represent
1373 // forward references to values that aren't yet defined.
1374 forwardRefOpState(UnknownLoc::get(config.getContext()),
1375 "builtin.unrealized_conversion_cast", ValueRange(),
1376 NoneType::get(config.getContext())),
1377 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1378
1379 /// Read the bytecode defined within `buffer` into the given block.
1380 LogicalResult read(Block *block,
1381 llvm::function_ref<bool(Operation *)> lazyOps);
1382
1383 /// Return the number of ops that haven't been materialized yet.
1384 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
1385
1386 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
1387
1388 /// Materialize the provided operation, invoke the lazyOpsCallback on every
1389 /// newly found lazy operation.
1390 LogicalResult
1392 llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1393 this->lazyOpsCallback = lazyOpsCallback;
1394 auto resetlazyOpsCallback =
1395 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1396 auto it = lazyLoadableOpsMap.find(op);
1397 assert(it != lazyLoadableOpsMap.end() &&
1398 "materialize called on non-materializable op");
1399 return materialize(it);
1400 }
1401
1402 /// Materialize all operations.
1403 LogicalResult materializeAll() {
1404 while (!lazyLoadableOpsMap.empty()) {
1405 if (failed(materialize(lazyLoadableOpsMap.begin())))
1406 return failure();
1407 }
1408 return success();
1409 }
1410
1411 /// Finalize the lazy-loading by calling back with every op that hasn't been
1412 /// materialized to let the client decide if the op should be deleted or
1413 /// materialized. The op is materialized if the callback returns true, deleted
1414 /// otherwise.
1415 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
1416 while (!lazyLoadableOps.empty()) {
1417 Operation *op = lazyLoadableOps.begin()->first;
1418 if (shouldMaterialize(op)) {
1419 if (failed(materialize(lazyLoadableOpsMap.find(op))))
1420 return failure();
1421 continue;
1422 }
1423 op->dropAllReferences();
1424 op->erase();
1425 lazyLoadableOps.pop_front();
1426 lazyLoadableOpsMap.erase(op);
1427 }
1428 return success();
1429 }
1430
1431private:
1432 LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
1433 assert(it != lazyLoadableOpsMap.end() &&
1434 "materialize called on non-materializable op");
1435 valueScopes.emplace_back();
1436 std::vector<RegionReadState> regionStack;
1437 regionStack.push_back(std::move(it->getSecond()->second));
1438 lazyLoadableOps.erase(it->getSecond());
1439 lazyLoadableOpsMap.erase(it);
1440
1441 while (!regionStack.empty())
1442 if (failed(parseRegions(regionStack, regionStack.back())))
1443 return failure();
1444 return success();
1445 }
1446
1447 LogicalResult checkSectionAlignment(
1448 unsigned alignment,
1449 function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
1450 // Check that the bytecode buffer meets the requested section alignment.
1451 //
1452 // If it does not, the virtual address of the item in the section will
1453 // not be aligned to the requested alignment.
1454 //
1455 // The typical case where this is necessary is the resource blob
1456 // optimization in `parseAsBlob` where we reference the weights from the
1457 // provided buffer instead of copying them to a new allocation.
1458 const bool isGloballyAligned =
1459 ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
1460
1461 if (!isGloballyAligned)
1462 return emitError("expected section alignment ")
1463 << alignment << " but bytecode buffer 0x"
1464 << Twine::utohexstr((uint64_t)buffer.getBufferStart())
1465 << " is not aligned";
1466
1467 return success();
1468 };
1469
1470 /// Return the context for this config.
1471 MLIRContext *getContext() const { return config.getContext(); }
1472
1473 /// Parse the bytecode version.
1474 LogicalResult parseVersion(EncodingReader &reader);
1475
1476 //===--------------------------------------------------------------------===//
1477 // Dialect Section
1478
1479 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1480
1481 /// Parse an operation name reference using the given reader, and set the
1482 /// `wasRegistered` flag that indicates if the bytecode was produced by a
1483 /// context where opName was registered.
1484 FailureOr<OperationName> parseOpName(EncodingReader &reader,
1485 std::optional<bool> &wasRegistered);
1486
1487 //===--------------------------------------------------------------------===//
1488 // Attribute/Type Section
1489
1490 /// Parse an attribute or type using the given reader.
1491 template <typename T>
1492 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
1493 return attrTypeReader.parseAttribute(reader, result);
1494 }
1495 LogicalResult parseType(EncodingReader &reader, Type &result) {
1496 return attrTypeReader.parseType(reader, result);
1497 }
1498
1499 //===--------------------------------------------------------------------===//
1500 // Resource Section
1501
1502 LogicalResult
1503 parseResourceSection(EncodingReader &reader,
1504 std::optional<ArrayRef<uint8_t>> resourceData,
1505 std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1506
1507 //===--------------------------------------------------------------------===//
1508 // IR Section
1509
1510 /// This struct represents the current read state of a range of regions. This
1511 /// struct is used to enable iterative parsing of regions.
1512 struct RegionReadState {
1513 RegionReadState(Operation *op, EncodingReader *reader,
1514 bool isIsolatedFromAbove)
1515 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1516 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1517 bool isIsolatedFromAbove)
1518 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1519 isIsolatedFromAbove(isIsolatedFromAbove) {}
1520
1521 /// The current regions being read.
1522 MutableArrayRef<Region>::iterator curRegion, endRegion;
1523 /// This is the reader to use for this region, this pointer is pointing to
1524 /// the parent region reader unless the current region is IsolatedFromAbove,
1525 /// in which case the pointer is pointing to the `owningReader` which is a
1526 /// section dedicated to the current region.
1527 EncodingReader *reader;
1528 std::unique_ptr<EncodingReader> owningReader;
1529
1530 /// The number of values defined immediately within this region.
1531 unsigned numValues = 0;
1532
1533 /// The current blocks of the region being read.
1534 SmallVector<Block *> curBlocks;
1535 Region::iterator curBlock = {};
1536
1537 /// The number of operations remaining to be read from the current block
1538 /// being read.
1539 uint64_t numOpsRemaining = 0;
1540
1541 /// A flag indicating if the regions being read are isolated from above.
1542 bool isIsolatedFromAbove = false;
1543 };
1544
1545 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
1546 LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
1547 RegionReadState &readState);
1548 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1549 RegionReadState &readState,
1550 bool &isIsolatedFromAbove);
1551
1552 LogicalResult parseRegion(RegionReadState &readState);
1553 LogicalResult parseBlockHeader(EncodingReader &reader,
1554 RegionReadState &readState);
1555 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
1556
1557 //===--------------------------------------------------------------------===//
1558 // Value Processing
1559
1560 /// Parse an operand reference using the given reader. Returns nullptr in the
1561 /// case of failure.
1562 Value parseOperand(EncodingReader &reader);
1563
1564 /// Sequentially define the given value range.
1565 LogicalResult defineValues(EncodingReader &reader, ValueRange values);
1566
1567 /// Create a value to use for a forward reference.
1568 Value createForwardRef();
1569
1570 //===--------------------------------------------------------------------===//
1571 // Use-list order helpers
1572
1573 /// This struct is a simple storage that contains information required to
1574 /// reorder the use-list of a value with respect to the pre-order traversal
1575 /// ordering.
1576 struct UseListOrderStorage {
1577 UseListOrderStorage(bool isIndexPairEncoding,
1578 SmallVector<unsigned, 4> &&indices)
1579 : indices(std::move(indices)),
1580 isIndexPairEncoding(isIndexPairEncoding) {};
1581 /// The vector containing the information required to reorder the
1582 /// use-list of a value.
1583 SmallVector<unsigned, 4> indices;
1584
1585 /// Whether indices represent a pair of type `(src, dst)` or it is a direct
1586 /// indexing, such as `dst = order[src]`.
1587 bool isIndexPairEncoding;
1588 };
1589
1590 /// Parse use-list order from bytecode for a range of values if available. The
1591 /// range is expected to be either a block argument or an op result range. On
1592 /// success, return a map of the position in the range and the use-list order
1593 /// encoding. The function assumes to know the size of the range it is
1594 /// processing.
1595 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1596 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1597 uint64_t rangeSize);
1598
1599 /// Shuffle the use-chain according to the order parsed.
1600 LogicalResult sortUseListOrder(Value value);
1601
1602 /// Recursively visit all the values defined within topLevelOp and sort the
1603 /// use-list orders according to the indices parsed.
1604 LogicalResult processUseLists(Operation *topLevelOp);
1605
1606 //===--------------------------------------------------------------------===//
1607 // Fields
1608
1609 /// This class represents a single value scope, in which a value scope is
1610 /// delimited by isolated from above regions.
1611 struct ValueScope {
1612 /// Push a new region state onto this scope, reserving enough values for
1613 /// those defined within the current region of the provided state.
1614 void push(RegionReadState &readState) {
1615 nextValueIDs.push_back(values.size());
1616 values.resize(values.size() + readState.numValues);
1617 }
1618
1619 /// Pop the values defined for the current region within the provided region
1620 /// state.
1621 void pop(RegionReadState &readState) {
1622 values.resize(values.size() - readState.numValues);
1623 nextValueIDs.pop_back();
1624 }
1625
1626 /// The set of values defined in this scope.
1627 std::vector<Value> values;
1628
1629 /// The ID for the next defined value for each region current being
1630 /// processed in this scope.
1631 SmallVector<unsigned, 4> nextValueIDs;
1632 };
1633
1634 /// The configuration of the parser.
1635 const ParserConfig &config;
1636
1637 /// A location to use when emitting errors.
1638 Location fileLoc;
1639
1640 /// Flag that indicates if lazyloading is enabled.
1641 bool lazyLoading;
1642
1643 /// Keep track of operations that have been lazy loaded (their regions haven't
1644 /// been materialized), along with the `RegionReadState` that allows to
1645 /// lazy-load the regions nested under the operation.
1646 LazyLoadableOpsInfo lazyLoadableOps;
1647 LazyLoadableOpsMap lazyLoadableOpsMap;
1648 llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1649
1650 /// The reader used to process attribute and types within the bytecode.
1651 AttrTypeReader attrTypeReader;
1652
1653 /// The version of the bytecode being read.
1654 uint64_t version = 0;
1655
1656 /// The producer of the bytecode being read.
1657 StringRef producer;
1658
1659 /// The table of IR units referenced within the bytecode file.
1660 SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1661 llvm::StringMap<BytecodeDialect *> dialectsMap;
1662 SmallVector<BytecodeOperationName> opNames;
1663
1664 /// The reader used to process resources within the bytecode.
1665 ResourceSectionReader resourceReader;
1666
1667 /// Worklist of values with custom use-list orders to process before the end
1668 /// of the parsing.
1669 DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1670
1671 /// The table of strings referenced within the bytecode file.
1672 StringSectionReader stringReader;
1673
1674 /// The table of properties referenced by the operation in the bytecode file.
1675 PropertiesSectionReader propertiesReader;
1676
1677 /// The current set of available IR value scopes.
1678 std::vector<ValueScope> valueScopes;
1679
1680 /// The global pre-order operation ordering.
1682
1683 /// A block containing the set of operations defined to create forward
1684 /// references.
1685 Block forwardRefOps;
1686
1687 /// A block containing previously created, and no longer used, forward
1688 /// reference operations.
1689 Block openForwardRefOps;
1690
1691 /// An operation state used when instantiating forward references.
1692 OperationState forwardRefOpState;
1693
1694 /// Reference to the input buffer.
1695 llvm::MemoryBufferRef buffer;
1696
1697 /// The optional owning source manager, which when present may be used to
1698 /// extend the lifetime of the input buffer.
1699 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1700};
1701
1703 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1704 EncodingReader reader(buffer.getBuffer(), fileLoc);
1705 this->lazyOpsCallback = lazyOpsCallback;
1706 auto resetlazyOpsCallback =
1707 llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
1708
1709 // Skip over the bytecode header, this should have already been checked.
1710 if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
1711 return failure();
1712 // Parse the bytecode version and producer.
1713 if (failed(parseVersion(reader)) ||
1714 failed(reader.parseNullTerminatedString(producer)))
1715 return failure();
1716
1717 // Add a diagnostic handler that attaches a note that includes the original
1718 // producer of the bytecode.
1719 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
1720 diag.attachNote() << "in bytecode version " << version
1721 << " produced by: " << producer;
1722 return failure();
1723 });
1724
1725 const auto checkSectionAlignment = [&](unsigned alignment) {
1726 return this->checkSectionAlignment(
1727 alignment, [&](const auto &msg) { return reader.emitError(msg); });
1728 };
1729
1730 // Parse the raw data for each of the top-level sections of the bytecode.
1731 std::optional<ArrayRef<uint8_t>>
1732 sectionDatas[bytecode::Section::kNumSections];
1733 while (!reader.empty()) {
1734 // Read the next section from the bytecode.
1735 bytecode::Section::ID sectionID;
1736 ArrayRef<uint8_t> sectionData;
1737 if (failed(
1738 reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
1739 return failure();
1740
1741 // Check for duplicate sections, we only expect one instance of each.
1742 if (sectionDatas[sectionID]) {
1743 return reader.emitError("duplicate top-level section: ",
1744 ::toString(sectionID));
1745 }
1746 sectionDatas[sectionID] = sectionData;
1747 }
1748 // Check that all of the required sections were found.
1749 for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
1750 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
1751 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
1752 return reader.emitError("missing data for top-level section: ",
1753 ::toString(sectionID));
1754 }
1755 }
1756
1757 // Process the string section first.
1758 if (failed(stringReader.initialize(
1759 fileLoc, *sectionDatas[bytecode::Section::kString])))
1760 return failure();
1761
1762 // Process the properties section.
1763 if (sectionDatas[bytecode::Section::kProperties] &&
1764 failed(propertiesReader.initialize(
1765 fileLoc, *sectionDatas[bytecode::Section::kProperties])))
1766 return failure();
1767
1768 // Process the dialect section.
1769 if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
1770 return failure();
1771
1772 // Process the resource section if present.
1773 if (failed(parseResourceSection(
1774 reader, sectionDatas[bytecode::Section::kResource],
1775 sectionDatas[bytecode::Section::kResourceOffset])))
1776 return failure();
1777
1778 // Process the attribute and type section.
1779 if (failed(attrTypeReader.initialize(
1780 dialects, *sectionDatas[bytecode::Section::kAttrType],
1781 *sectionDatas[bytecode::Section::kAttrTypeOffset])))
1782 return failure();
1783
1784 // Finally, process the IR section.
1785 return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
1786}
1787
1788LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1789 if (failed(reader.parseVarInt(version)))
1790 return failure();
1791
1792 // Validate the bytecode version.
1793 uint64_t currentVersion = bytecode::kVersion;
1794 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
1795 if (version < minSupportedVersion) {
1796 return reader.emitError("bytecode version ", version,
1797 " is older than the current version of ",
1798 currentVersion, ", and upgrade is not supported");
1799 }
1800 if (version > currentVersion) {
1801 return reader.emitError("bytecode version ", version,
1802 " is newer than the current version ",
1803 currentVersion);
1804 }
1805 // Override any request to lazy-load if the bytecode version is too old.
1806 if (version < bytecode::kLazyLoading)
1807 lazyLoading = false;
1808 return success();
1809}
1810
1811//===----------------------------------------------------------------------===//
1812// Dialect Section
1813//===----------------------------------------------------------------------===//
1814
1815LogicalResult BytecodeDialect::load(const DialectReader &reader,
1816 MLIRContext *ctx) {
1817 if (dialect)
1818 return success();
1819 Dialect *loadedDialect = ctx->getOrLoadDialect(name);
1820 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
1821 return reader.emitError("dialect '")
1822 << name
1823 << "' is unknown. If this is intended, please call "
1824 "allowUnregisteredDialects() on the MLIRContext, or use "
1825 "-allow-unregistered-dialect with the MLIR tool used.";
1826 }
1827 dialect = loadedDialect;
1828
1829 // If the dialect was actually loaded, check to see if it has a bytecode
1830 // interface.
1831 if (loadedDialect)
1832 interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
1833 if (!versionBuffer.empty()) {
1834 if (!interface)
1835 return reader.emitError("dialect '")
1836 << name
1837 << "' does not implement the bytecode interface, "
1838 "but found a version entry";
1839 EncodingReader encReader(versionBuffer, reader.getLoc());
1840 DialectReader versionReader = reader.withEncodingReader(encReader);
1841 loadedVersion = interface->readVersion(versionReader);
1842 if (!loadedVersion)
1843 return failure();
1844 }
1845 return success();
1846}
1847
1848LogicalResult
1849BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1850 EncodingReader sectionReader(sectionData, fileLoc);
1851
1852 // Parse the number of dialects in the section.
1853 uint64_t numDialects;
1854 if (failed(sectionReader.parseVarInt(numDialects)))
1855 return failure();
1856 dialects.resize(numDialects);
1857
1858 const auto checkSectionAlignment = [&](unsigned alignment) {
1859 return this->checkSectionAlignment(alignment, [&](const auto &msg) {
1860 return sectionReader.emitError(msg);
1861 });
1862 };
1863
1864 // Parse each of the dialects.
1865 for (uint64_t i = 0; i < numDialects; ++i) {
1866 dialects[i] = std::make_unique<BytecodeDialect>();
1867 /// Before version kDialectVersioning, there wasn't any versioning available
1868 /// for dialects, and the entryIdx represent the string itself.
1869 if (version < bytecode::kDialectVersioning) {
1870 if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
1871 return failure();
1872 continue;
1873 }
1874
1875 // Parse ID representing dialect and version.
1876 uint64_t dialectNameIdx;
1877 bool versionAvailable;
1878 if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
1879 versionAvailable)))
1880 return failure();
1881 if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
1882 dialects[i]->name)))
1883 return failure();
1884 if (versionAvailable) {
1885 bytecode::Section::ID sectionID;
1886 if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
1887 dialects[i]->versionBuffer)))
1888 return failure();
1889 if (sectionID != bytecode::Section::kDialectVersions) {
1890 emitError(fileLoc, "expected dialect version section");
1891 return failure();
1892 }
1893 }
1894 dialectsMap[dialects[i]->name] = dialects[i].get();
1895 }
1896
1897 // Parse the operation names, which are grouped by dialect.
1898 auto parseOpName = [&](BytecodeDialect *dialect) {
1899 StringRef opName;
1900 std::optional<bool> wasRegistered;
1901 // Prior to version kNativePropertiesEncoding, the information about wheter
1902 // an op was registered or not wasn't encoded.
1904 if (failed(stringReader.parseString(sectionReader, opName)))
1905 return failure();
1906 } else {
1907 bool wasRegisteredFlag;
1908 if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
1909 wasRegisteredFlag)))
1910 return failure();
1911 wasRegistered = wasRegisteredFlag;
1912 }
1913 opNames.emplace_back(dialect, opName, wasRegistered);
1914 return success();
1915 };
1916 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
1917 // where the number of ops are known.
1919 uint64_t numOps;
1920 if (failed(sectionReader.parseVarInt(numOps)))
1921 return failure();
1922 opNames.reserve(numOps);
1923 }
1924 while (!sectionReader.empty())
1925 if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
1926 return failure();
1927 return success();
1928}
1929
1930FailureOr<OperationName>
1931BytecodeReader::Impl::parseOpName(EncodingReader &reader,
1932 std::optional<bool> &wasRegistered) {
1933 BytecodeOperationName *opName = nullptr;
1934 if (failed(parseEntry(reader, opNames, opName, "operation name")))
1935 return failure();
1936 wasRegistered = opName->wasRegistered;
1937 // Check to see if this operation name has already been resolved. If we
1938 // haven't, load the dialect and build the operation name.
1939 if (!opName->opName) {
1940 // If the opName is empty, this is because we use to accept names such as
1941 // `foo` without any `.` separator. We shouldn't tolerate this in textual
1942 // format anymore but for now we'll be backward compatible. This can only
1943 // happen with unregistered dialects.
1944 if (opName->name.empty()) {
1945 opName->opName.emplace(opName->dialect->name, getContext());
1946 } else {
1947 // Load the dialect and its version.
1948 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1949 dialectsMap, reader, version);
1950 if (failed(opName->dialect->load(dialectReader, getContext())))
1951 return failure();
1952 opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
1953 getContext());
1954 }
1955 }
1956 return *opName->opName;
1957}
1958
1959//===----------------------------------------------------------------------===//
1960// Resource Section
1961//===----------------------------------------------------------------------===//
1962
1963LogicalResult BytecodeReader::Impl::parseResourceSection(
1964 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
1965 std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
1966 // Ensure both sections are either present or not.
1967 if (resourceData.has_value() != resourceOffsetData.has_value()) {
1968 if (resourceOffsetData)
1969 return emitError(fileLoc, "unexpected resource offset section when "
1970 "resource section is not present");
1971 return emitError(
1972 fileLoc,
1973 "expected resource offset section when resource section is present");
1974 }
1975
1976 // If the resource sections are absent, there is nothing to do.
1977 if (!resourceData)
1978 return success();
1979
1980 // Initialize the resource reader with the resource sections.
1981 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1982 dialectsMap, reader, version);
1983 return resourceReader.initialize(fileLoc, config, dialects, stringReader,
1984 *resourceData, *resourceOffsetData,
1985 dialectReader, bufferOwnerRef);
1986}
1987
1988//===----------------------------------------------------------------------===//
1989// UseListOrder Helpers
1990//===----------------------------------------------------------------------===//
1991
1992FailureOr<BytecodeReader::Impl::UseListMapT>
1993BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
1994 uint64_t numResults) {
1995 BytecodeReader::Impl::UseListMapT map;
1996 uint64_t numValuesToRead = 1;
1997 if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead)))
1998 return failure();
1999
2000 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
2001 uint64_t resultIdx = 0;
2002 if (numResults > 1 && failed(reader.parseVarInt(resultIdx)))
2003 return failure();
2004
2005 uint64_t numValues;
2006 bool indexPairEncoding;
2007 if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
2008 return failure();
2009
2010 SmallVector<unsigned, 4> useListOrders;
2011 for (size_t idx = 0; idx < numValues; idx++) {
2012 uint64_t index;
2013 if (failed(reader.parseVarInt(index)))
2014 return failure();
2015 useListOrders.push_back(index);
2016 }
2017
2018 // Store in a map the result index
2019 map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
2020 std::move(useListOrders)));
2021 }
2022
2023 return map;
2024}
2025
2026/// Sorts each use according to the order specified in the use-list parsed. If
2027/// the custom use-list is not found, this means that the order needs to be
2028/// consistent with the reverse pre-order walk of the IR. If multiple uses lie
2029/// on the same operation, the order will follow the reverse operand number
2030/// ordering.
2031LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
2032 // Early return for trivial use-lists.
2033 if (value.use_empty() || value.hasOneUse())
2034 return success();
2035
2036 bool hasIncomingOrder =
2037 valueToUseListMap.contains(value.getAsOpaquePointer());
2038
2039 // Compute the current order of the use-list with respect to the global
2040 // ordering. Detect if the order is already sorted while doing so.
2041 bool alreadySorted = true;
2042 auto &firstUse = *value.use_begin();
2043 uint64_t prevID =
2044 bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
2045 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
2046 for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
2047 uint64_t currentID = bytecode::getUseID(
2048 item.value(), operationIDs.at(item.value().getOwner()));
2049 alreadySorted &= prevID > currentID;
2050 currentOrder.push_back({item.index(), currentID});
2051 prevID = currentID;
2052 }
2053
2054 // If the order is already sorted, and there wasn't a custom order to apply
2055 // from the bytecode file, we are done.
2056 if (alreadySorted && !hasIncomingOrder)
2057 return success();
2058
2059 // If not already sorted, sort the indices of the current order by descending
2060 // useIDs.
2061 if (!alreadySorted)
2062 std::sort(
2063 currentOrder.begin(), currentOrder.end(),
2064 [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
2065
2066 if (!hasIncomingOrder) {
2067 // If the bytecode file did not contain any custom use-list order, it means
2068 // that the order was descending useID. Hence, shuffle by the first index
2069 // of the `currentOrder` pair.
2070 SmallVector<unsigned> shuffle(llvm::make_first_range(currentOrder));
2071 value.shuffleUseList(shuffle);
2072 return success();
2073 }
2074
2075 // Pull the custom order info from the map.
2076 UseListOrderStorage customOrder =
2077 valueToUseListMap.at(value.getAsOpaquePointer());
2078 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
2079 uint64_t numUses = value.getNumUses();
2080
2081 // If the encoding was a pair of indices `(src, dst)` for every permutation,
2082 // reconstruct the shuffle vector for every use. Initialize the shuffle vector
2083 // as identity, and then apply the mapping encoded in the indices.
2084 if (customOrder.isIndexPairEncoding) {
2085 // Return failure if the number of indices was not representing pairs.
2086 if (shuffle.size() & 1)
2087 return failure();
2088
2089 SmallVector<unsigned, 4> newShuffle(numUses);
2090 size_t idx = 0;
2091 std::iota(newShuffle.begin(), newShuffle.end(), idx);
2092 for (idx = 0; idx < shuffle.size(); idx += 2)
2093 newShuffle[shuffle[idx]] = shuffle[idx + 1];
2094
2095 shuffle = std::move(newShuffle);
2096 }
2097
2098 // Make sure that the indices represent a valid mapping. That is, the sum of
2099 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
2100 // duplicates are allowed in the list.
2102 uint64_t accumulator = 0;
2103 for (const auto &elem : shuffle) {
2104 if (!set.insert(elem).second)
2105 return failure();
2106 accumulator += elem;
2107 }
2108 if (numUses != shuffle.size() ||
2109 accumulator != (((numUses - 1) * numUses) >> 1))
2110 return failure();
2111
2112 // Apply the current ordering map onto the shuffle vector to get the final
2113 // use-list sorting indices before shuffling.
2114 shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2115 currentOrder, [&](auto item) { return shuffle[item.first]; }));
2116 value.shuffleUseList(shuffle);
2117 return success();
2118}
2119
2120LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2121 // Precompute operation IDs according to the pre-order walk of the IR. We
2122 // can't do this while parsing since parseRegions ordering is not strictly
2123 // equal to the pre-order walk.
2124 unsigned operationID = 0;
2125 topLevelOp->walk<mlir::WalkOrder::PreOrder>(
2126 [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
2127
2128 auto blockWalk = topLevelOp->walk([this](Block *block) {
2129 for (auto arg : block->getArguments())
2130 if (failed(sortUseListOrder(arg)))
2131 return WalkResult::interrupt();
2132 return WalkResult::advance();
2133 });
2134
2135 auto resultWalk = topLevelOp->walk([this](Operation *op) {
2136 for (auto result : op->getResults())
2137 if (failed(sortUseListOrder(result)))
2138 return WalkResult::interrupt();
2139 return WalkResult::advance();
2140 });
2141
2142 return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2143}
2144
2145//===----------------------------------------------------------------------===//
2146// IR Section
2147//===----------------------------------------------------------------------===//
2148
2149LogicalResult
2150BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2151 Block *block) {
2152 EncodingReader reader(sectionData, fileLoc);
2153
2154 // A stack of operation regions currently being read from the bytecode.
2155 std::vector<RegionReadState> regionStack;
2156
2157 // Parse the top-level block using a temporary module operation.
2158 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2159 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
2160 regionStack.back().curBlocks.push_back(moduleOp->getBody());
2161 regionStack.back().curBlock = regionStack.back().curRegion->begin();
2162 if (failed(parseBlockHeader(reader, regionStack.back())))
2163 return failure();
2164 valueScopes.emplace_back();
2165 valueScopes.back().push(regionStack.back());
2166
2167 // Iteratively parse regions until everything has been resolved.
2168 while (!regionStack.empty())
2169 if (failed(parseRegions(regionStack, regionStack.back())))
2170 return failure();
2171 if (!forwardRefOps.empty()) {
2172 return reader.emitError(
2173 "not all forward unresolved forward operand references");
2174 }
2175
2176 // Sort use-lists according to what specified in bytecode.
2177 if (failed(processUseLists(*moduleOp)))
2178 return reader.emitError(
2179 "parsed use-list orders were invalid and could not be applied");
2180
2181 // Resolve dialect version.
2182 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2183 // Parsing is complete, give an opportunity to each dialect to visit the
2184 // IR and perform upgrades.
2185 if (!byteCodeDialect->loadedVersion)
2186 continue;
2187 if (byteCodeDialect->interface &&
2188 failed(byteCodeDialect->interface->upgradeFromVersion(
2189 *moduleOp, *byteCodeDialect->loadedVersion)))
2190 return failure();
2191 }
2192
2193 // Verify that the parsed operations are valid.
2194 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
2195 return failure();
2196
2197 // Splice the parsed operations over to the provided top-level block.
2198 auto &parsedOps = moduleOp->getBody()->getOperations();
2199 auto &destOps = block->getOperations();
2200 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
2201 return success();
2202}
2203
/// Continues reading the regions described by `readState`, which is the top
/// of `regionStack` (parseIRSection calls this with `regionStack.back()`).
///
/// When an operation with regions is reached, and its regions are not
/// deferred for lazy loading, a child RegionReadState is pushed onto
/// `regionStack` and this function returns success immediately. The caller
/// then resumes with the new top; the current region is resumed later from
/// the saved `curBlock` / `numOpsRemaining`.
///
/// Once all regions of `readState` are done, its state is popped from
/// `regionStack`. If it was isolated from above, the innermost value scope is
/// popped as well.
LogicalResult
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
                                   RegionReadState &readState) {
  // Alignment errors for nested region sections are reported at `fileLoc`.
  const auto checkSectionAlignment = [&](unsigned alignment) {
    return this->checkSectionAlignment(
        alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
  };

  // Process regions, blocks, and operations until the end or if a nested
  // region is encountered. In this case we push a new state in regionStack and
  // return, the processing of the current region will resume afterward.
  for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
    // If the current block hasn't been setup yet, parse the header for this
    // region. The current block is already setup when this function was
    // interrupted to recurse down in a nested region and we resume the current
    // block after processing the nested region.
    if (readState.curBlock == Region::iterator()) {
      if (failed(parseRegion(readState)))
        return failure();

      // If the region is empty, there is nothing more to do. (parseRegion only
      // pushes a value scope entry for non-empty regions, so skipping the
      // `pop` below is consistent.)
      if (readState.curRegion->empty())
        continue;
    }

    // Parse the blocks within the region.
    EncodingReader &reader = *readState.reader;
    do {
      // Post-decrement: the counter is already decremented for the op being
      // read, so resuming after a nested region continues with the next op.
      // When the loop exits it has gone past zero; parseBlockHeader
      // re-initializes it for the next block.
      while (readState.numOpsRemaining--) {
        // Read in the next operation. We don't read its regions directly, we
        // handle those afterwards as necessary.
        bool isIsolatedFromAbove = false;
        FailureOr<Operation *> op =
            parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
        if (failed(op))
          return failure();

        // If the op has regions, add it to the stack for processing and return:
        // we stop the processing of the current region and resume it after the
        // inner one is completed. Unless LazyLoading is activated in which case
        // nested region parsing is delayed.
        if ((*op)->getNumRegions()) {
          RegionReadState childState(*op, &reader, isIsolatedFromAbove);

          // Isolated regions are encoded as a section in version 2 and above.
          if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
            bytecode::Section::ID sectionID;
            ArrayRef<uint8_t> sectionData;
            if (failed(reader.parseSection(sectionID, checkSectionAlignment,
                                           sectionData)))
              return failure();
            if (sectionID != bytecode::Section::kIR)
              return emitError(fileLoc, "expected IR section for region");
            // The child reads from its own section; the parent reader has
            // already skipped past it.
            childState.owningReader =
                std::make_unique<EncodingReader>(sectionData, fileLoc);
            childState.reader = childState.owningReader.get();

            // If the user has a callback set, they have the opportunity to
            // control lazyloading as we go. The op is deferred when lazy
            // loading is on and either no callback is set or it returns false.
            if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
              lazyLoadableOps.emplace_back(*op, std::move(childState));
              lazyLoadableOpsMap.try_emplace(*op,
                                             std::prev(lazyLoadableOps.end()));
              continue;
            }
          }
          regionStack.push_back(std::move(childState));

          // If the op is isolated from above, push a new value scope.
          if (isIsolatedFromAbove)
            valueScopes.emplace_back();
          return success();
        }
      }

      // Move to the next block of the region.
      if (++readState.curBlock == readState.curRegion->end())
        break;
      if (failed(parseBlockHeader(reader, readState)))
        return failure();
    } while (true);

    // Reset the current block and any values reserved for this region.
    readState.curBlock = {};
    valueScopes.back().pop(readState);
  }

  // When the regions have been fully parsed, pop them off of the read stack. If
  // the regions were isolated from above, we also pop the last value scope.
  if (readState.isIsolatedFromAbove) {
    assert(!valueScopes.empty() && "Expect a valueScope after reading region");
    valueScopes.pop_back();
  }
  assert(!regionStack.empty() && "Expect a regionStack after reading region");
  regionStack.pop_back();
  return success();
}
2301
2302FailureOr<Operation *>
2303BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2304 RegionReadState &readState,
2305 bool &isIsolatedFromAbove) {
2306 // Parse the name of the operation.
2307 std::optional<bool> wasRegistered;
2308 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2309 if (failed(opName))
2310 return failure();
2311
2312 // Parse the operation mask, which indicates which components of the operation
2313 // are present.
2314 uint8_t opMask;
2315 if (failed(reader.parseByte(opMask)))
2316 return failure();
2317
2318 /// Parse the location.
2319 LocationAttr opLoc;
2320 if (failed(parseAttribute(reader, opLoc)))
2321 return failure();
2322
2323 // With the location and name resolved, we can start building the operation
2324 // state.
2325 OperationState opState(opLoc, *opName);
2326
2327 // Parse the attributes of the operation.
2329 DictionaryAttr dictAttr;
2330 if (failed(parseAttribute(reader, dictAttr)))
2331 return failure();
2332 opState.attributes = dictAttr;
2333 }
2334
2336 // kHasProperties wasn't emitted in older bytecode, we should never get
2337 // there without also having the `wasRegistered` flag available.
2338 if (!wasRegistered)
2339 return emitError(fileLoc,
2340 "Unexpected missing `wasRegistered` opname flag at "
2341 "bytecode version ")
2342 << version << " with properties.";
2343 // When an operation is emitted without being registered, the properties are
2344 // stored as an attribute. Otherwise the op must implement the bytecode
2345 // interface and control the serialization.
2346 if (wasRegistered) {
2347 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2348 dialectsMap, reader, version);
2349 if (failed(
2350 propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
2351 return failure();
2352 } else {
2353 // If the operation wasn't registered when it was emitted, the properties
2354 // was serialized as an attribute.
2355 if (failed(parseAttribute(reader, opState.propertiesAttr)))
2356 return failure();
2357 }
2358 }
2359
2360 /// Parse the results of the operation.
2362 uint64_t numResults;
2363 if (failed(reader.parseVarInt(numResults)))
2364 return failure();
2365 opState.types.resize(numResults);
2366 for (int i = 0, e = numResults; i < e; ++i)
2367 if (failed(parseType(reader, opState.types[i])))
2368 return failure();
2369 }
2370
2371 /// Parse the operands of the operation.
2373 uint64_t numOperands;
2374 if (failed(reader.parseVarInt(numOperands)))
2375 return failure();
2376 opState.operands.resize(numOperands);
2377 for (int i = 0, e = numOperands; i < e; ++i)
2378 if (!(opState.operands[i] = parseOperand(reader)))
2379 return failure();
2380 }
2381
2382 /// Parse the successors of the operation.
2384 uint64_t numSuccs;
2385 if (failed(reader.parseVarInt(numSuccs)))
2386 return failure();
2387 opState.successors.resize(numSuccs);
2388 for (int i = 0, e = numSuccs; i < e; ++i) {
2389 if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
2390 "successor")))
2391 return failure();
2392 }
2393 }
2394
2395 /// Parse the use-list orders for the results of the operation. Use-list
2396 /// orders are available since version 3 of the bytecode.
2397 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2398 if (version >= bytecode::kUseListOrdering &&
2400 size_t numResults = opState.types.size();
2401 auto parseResult = parseUseListOrderForRange(reader, numResults);
2402 if (failed(parseResult))
2403 return failure();
2404 resultIdxToUseListMap = std::move(*parseResult);
2405 }
2406
2407 /// Parse the regions of the operation.
2409 uint64_t numRegions;
2410 if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
2411 return failure();
2412
2413 opState.regions.reserve(numRegions);
2414 for (int i = 0, e = numRegions; i < e; ++i)
2415 opState.regions.push_back(std::make_unique<Region>());
2416 }
2417
2418 // Create the operation at the back of the current block.
2419 Operation *op = Operation::create(opState);
2420 readState.curBlock->push_back(op);
2421
2422 // If the operation had results, update the value references. We don't need to
2423 // do this if the current value scope is empty. That is, the op was not
2424 // encoded within a parent region.
2425 if (readState.numValues && op->getNumResults() &&
2426 failed(defineValues(reader, op->getResults())))
2427 return failure();
2428
2429 /// Store a map for every value that received a custom use-list order from the
2430 /// bytecode file.
2431 if (resultIdxToUseListMap.has_value()) {
2432 for (size_t idx = 0; idx < op->getNumResults(); idx++) {
2433 if (resultIdxToUseListMap->contains(idx)) {
2434 valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
2435 resultIdxToUseListMap->at(idx));
2436 }
2437 }
2438 }
2439 return op;
2440}
2441
2442LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2443 EncodingReader &reader = *readState.reader;
2444
2445 // Parse the number of blocks in the region.
2446 uint64_t numBlocks;
2447 if (failed(reader.parseVarInt(numBlocks)))
2448 return failure();
2449
2450 // If the region is empty, there is nothing else to do.
2451 if (numBlocks == 0)
2452 return success();
2453
2454 // Parse the number of values defined in this region.
2455 uint64_t numValues;
2456 if (failed(reader.parseVarInt(numValues)))
2457 return failure();
2458 readState.numValues = numValues;
2459
2460 // Create the blocks within this region. We do this before processing so that
2461 // we can rely on the blocks existing when creating operations.
2462 readState.curBlocks.clear();
2463 readState.curBlocks.reserve(numBlocks);
2464 for (uint64_t i = 0; i < numBlocks; ++i) {
2465 readState.curBlocks.push_back(new Block());
2466 readState.curRegion->push_back(readState.curBlocks.back());
2467 }
2468
2469 // Prepare the current value scope for this region.
2470 valueScopes.back().push(readState);
2471
2472 // Parse the entry block of the region.
2473 readState.curBlock = readState.curRegion->begin();
2474 return parseBlockHeader(reader, readState);
2475}
2476
2477LogicalResult
2478BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2479 RegionReadState &readState) {
2480 bool hasArgs;
2481 if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
2482 return failure();
2483
2484 // Parse the arguments of the block.
2485 if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
2486 return failure();
2487
2488 // Uselist orders are available since version 3 of the bytecode.
2489 if (version < bytecode::kUseListOrdering)
2490 return success();
2491
2492 uint8_t hasUseListOrders = 0;
2493 if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
2494 return failure();
2495
2496 if (!hasUseListOrders)
2497 return success();
2498
2499 Block &blk = *readState.curBlock;
2500 auto argIdxToUseListMap =
2501 parseUseListOrderForRange(reader, blk.getNumArguments());
2502 if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
2503 return failure();
2504
2505 for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
2506 if (argIdxToUseListMap->contains(idx))
2507 valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
2508 argIdxToUseListMap->at(idx));
2509
2510 // We don't parse the operations of the block here, that's done elsewhere.
2511 return success();
2512}
2513
2514LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2515 Block *block) {
2516 // Parse the value ID for the first argument, and the number of arguments.
2517 uint64_t numArgs;
2518 if (failed(reader.parseVarInt(numArgs)))
2519 return failure();
2520
2521 SmallVector<Type> argTypes;
2522 SmallVector<Location> argLocs;
2523 argTypes.reserve(numArgs);
2524 argLocs.reserve(numArgs);
2525
2526 Location unknownLoc = UnknownLoc::get(config.getContext());
2527 while (numArgs--) {
2528 Type argType;
2529 LocationAttr argLoc = unknownLoc;
2531 // Parse the type with hasLoc flag to determine if it has type.
2532 uint64_t typeIdx;
2533 bool hasLoc;
2534 if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
2535 !(argType = attrTypeReader.resolveType(typeIdx)))
2536 return failure();
2537 if (hasLoc && failed(parseAttribute(reader, argLoc)))
2538 return failure();
2539 } else {
2540 // All args has type and location.
2541 if (failed(parseType(reader, argType)) ||
2542 failed(parseAttribute(reader, argLoc)))
2543 return failure();
2544 }
2545 argTypes.push_back(argType);
2546 argLocs.push_back(argLoc);
2547 }
2548 block->addArguments(argTypes, argLocs);
2549 return defineValues(reader, block->getArguments());
2550}
2551
2552//===----------------------------------------------------------------------===//
2553// Value Processing
2554//===----------------------------------------------------------------------===//
2555
2556Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2557 std::vector<Value> &values = valueScopes.back().values;
2558 Value *value = nullptr;
2559 if (failed(parseEntry(reader, values, value, "value")))
2560 return Value();
2561
2562 // Create a new forward reference if necessary.
2563 if (!*value)
2564 *value = createForwardRef();
2565 return *value;
2566}
2567
2568LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2569 ValueRange newValues) {
2570 ValueScope &valueScope = valueScopes.back();
2571 std::vector<Value> &values = valueScope.values;
2572
2573 unsigned &valueID = valueScope.nextValueIDs.back();
2574 unsigned valueIDEnd = valueID + newValues.size();
2575 if (valueIDEnd > values.size()) {
2576 return reader.emitError(
2577 "value index range was outside of the expected range for "
2578 "the parent region, got [",
2579 valueID, ", ", valueIDEnd, "), but the maximum index was ",
2580 values.size() - 1);
2581 }
2582
2583 // Assign the values and update any forward references.
2584 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2585 Value newValue = newValues[i];
2586
2587 // Check to see if a definition for this value already exists.
2588 if (Value oldValue = std::exchange(values[valueID], newValue)) {
2589 Operation *forwardRefOp = oldValue.getDefiningOp();
2590
2591 // Assert that this is a forward reference operation. Given how we compute
2592 // definition ids (incrementally as we parse), it shouldn't be possible
2593 // for the value to be defined any other way.
2594 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
2595 "value index was already defined?");
2596
2597 oldValue.replaceAllUsesWith(newValue);
2598 forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
2599 }
2600 }
2601 return success();
2602}
2603
2604Value BytecodeReader::Impl::createForwardRef() {
2605 // Check for an available existing operation to use. Otherwise, create a new
2606 // fake operation to use for the reference.
2607 if (!openForwardRefOps.empty()) {
2608 Operation *op = &openForwardRefOps.back();
2609 op->moveBefore(&forwardRefOps, forwardRefOps.end());
2610 } else {
2611 forwardRefOps.push_back(Operation::create(forwardRefOpState));
2612 }
2613 return forwardRefOps.back().getResult(0);
2614}
2615
2616//===----------------------------------------------------------------------===//
2617// Entry Points
2618//===----------------------------------------------------------------------===//
2619
2621
2623 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
2624 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2625 Location sourceFileLoc =
2626 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2627 /*line=*/0, /*column=*/0);
2628 impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
2629 bufferOwnerRef);
2630}
2631
2633 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2634 return impl->read(block, lazyOpsCallback);
2635}
2636
2638 return impl->getNumOpsToMaterialize();
2639}
2640
2642 return impl->isMaterializable(op);
2643}
2644
2646 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2647 return impl->materialize(op, lazyOpsCallback);
2648}
2649
2650LogicalResult
2652 return impl->finalize(shouldMaterialize);
2653}
2654
2655bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
2656 return buffer.getBuffer().starts_with("ML\xefR");
2657}
2658
2659/// Read the bytecode from the provided memory buffer reference.
2660/// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2661/// and may be used to extend the lifetime of the buffer.
2662static LogicalResult
2663readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
2664 const ParserConfig &config,
2665 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2666 Location sourceFileLoc =
2667 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2668 /*line=*/0, /*column=*/0);
2669 if (!isBytecode(buffer)) {
2670 return emitError(sourceFileLoc,
2671 "input buffer is not an MLIR bytecode file");
2672 }
2673
2674 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
2675 buffer, bufferOwnerRef);
2676 return reader.read(block, /*lazyOpsCallback=*/nullptr);
2677}
2678
2679LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
2680 const ParserConfig &config) {
2681 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
2682}
2683LogicalResult
2684mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
2685 Block *block, const ParserConfig &config) {
2686 return readBytecodeFileImpl(
2687 *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
2688 sourceMgr);
2689}
return success()
static LogicalResult parseDialectGrouping(EncodingReader &reader, MutableArrayRef< std::unique_ptr< BytecodeDialect > > dialects, function_ref< LogicalResult(BytecodeDialect *)> entryCallback)
Parse a single dialect group encoded in the byte stream.
static LogicalResult readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
Read the bytecode from the provided memory buffer reference.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version)
Returns true if the given top-level section ID is optional.
static LogicalResult parseResourceGroup(Location fileLoc, bool allowEmpty, EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef, function_ref< StringRef(StringRef)> remapKey={}, function_ref< LogicalResult(StringRef)> processKeyFn={})
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, uint64_t index, T &entry, StringRef entryStr)
Resolve an index into the given entry list.
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, T &entry, StringRef entryStr)
Parse and resolve an index into the given entry list.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b getContext())
auto load
static std::string diag(const llvm::Value &value)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:984
MutableArrayRef< char > getMutableData()
Return a mutable reference to the raw underlying data of this blob.
Definition AsmState.h:157
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
bool isMutable() const
Return if the data of this blob is mutable.
Definition AsmState.h:164
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
OpListType & getOperations()
Definition Block.h:137
BlockArgListType getArguments()
Definition Block.h:87
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Type > > > getTypeCallbacks() const
ArrayRef< std::unique_ptr< AttrTypeBytecodeReader< Attribute > > > getAttributeCallbacks() const
Returns the callbacks available to the parser.
This class is used to read a bytecode buffer and translate it into MLIR.
LogicalResult materializeAll()
Materialize all operations.
LogicalResult read(Block *block, llvm::function_ref< bool(Operation *)> lazyOps)
Read the bytecode defined within buffer into the given block.
bool isMaterializable(Operation *op)
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, llvm::MemoryBufferRef buffer, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef)
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize)
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback)
Materialize the provided operation, invoke the lazyOpsCallback on every newly found lazy operation.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
LogicalResult materialize(Operation *op, llvm::function_ref< bool(Operation *)> lazyOpsCallback=[](Operation *) { return false;})
Materialize the provided operation.
LogicalResult finalize(function_ref< bool(Operation *)> shouldMaterialize=[](Operation *) { return true;})
Finalize the lazy-loading by calling back with every op that hasn't been materialized to let the clie...
BytecodeReader(llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, const std::shared_ptr< llvm::SourceMgr > &bufferOwnerRef={})
Create a bytecode reader for the given buffer.
int64_t getNumOpsToMaterialize() const
Return the number of ops that haven't been materialized yet.
bool isMaterializable(Operation *op)
Return true if the provided op is materializable.
LogicalResult readTopLevel(Block *block, llvm::function_ref< bool(Operation *)> lazyOps=[](Operation *) { return false;})
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
This class represents a diagnostic that is inflight and set to be reported.
InFlightDiagnostic & append(Args &&...args) &
Append arguments to the diagnostic.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
bool allowsUnregisteredDialects()
Return true if we allow to create operation for unregistered dialects.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
bool isRegistered() const
Return if this operation is registered.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void dropAllReferences()
This drops all operand uses from this operation, which is an essential step in breaking cyclic depend...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class represents a configuration for the MLIR assembly parser.
Definition AsmState.h:469
BytecodeReaderConfig & getBytecodeReaderConfig() const
Returns the parsing configurations associated to the bytecode read.
Definition AsmState.h:489
BlockListType::iterator iterator
Definition Region.h:52
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
static AsmResourceBlob allocateWithAlign(ArrayRef< char > data, size_t align, AsmResourceBlob::DeleterFn deleter={}, bool dataIsMutable=false)
Create a new unmanaged resource directly referencing the provided data.
Definition AsmState.h:228
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
void shuffleUseList(ArrayRef< unsigned > indices)
Shuffle the use list order according to the provided indices.
Definition Value.cpp:106
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition Value.h:233
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
use_iterator use_begin() const
Definition Value.h:184
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
@ kAttrType
This section contains the attributes and types referenced within an IR module.
Definition Encoding.h:73
@ kAttrTypeOffset
This section contains the offsets for the attribute and types within the AttrType section.
Definition Encoding.h:77
@ kIR
This section contains the list of operations serialized into the bytecode, and their nested regions/o...
Definition Encoding.h:81
@ kResource
This section contains the resources of the bytecode.
Definition Encoding.h:84
@ kResourceOffset
This section contains the offsets of resources within the Resource section.
Definition Encoding.h:88
@ kDialect
This section contains the dialects referenced within an IR module.
Definition Encoding.h:69
@ kString
This section contains strings referenced within the bytecode.
Definition Encoding.h:66
@ kDialectVersions
This section contains the versions of each dialect.
Definition Encoding.h:91
@ kProperties
This section contains the properties for the operations.
Definition Encoding.h:94
@ kNumSections
The total number of section types.
Definition Encoding.h:97
static uint64_t getUseID(OperandT &val, unsigned ownerID)
Get the unique ID of a value use.
Definition Encoding.h:127
@ kUseListOrdering
Use-list ordering started to be encoded in version 3.
Definition Encoding.h:38
@ kAlignmentByte
An arbitrary value used to fill alignment padding.
Definition Encoding.h:56
@ kVersion
The current bytecode version.
Definition Encoding.h:53
@ kLazyLoading
Support for lazy-loading of isolated region was added in version 2.
Definition Encoding.h:35
@ kDialectVersioning
Dialect versioning was added in version 1.
Definition Encoding.h:32
@ kElideUnknownBlockArgLocation
Avoid recording unknown locations on block arguments (compression) started in version 4.
Definition Encoding.h:42
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attribute...
Definition Encoding.h:46
@ kMinSupportedVersion
The minimum supported version of the bytecode.
Definition Encoding.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
StringRef toString(AsmResourceEntryKind kind)
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
bool isBytecode(llvm::MemoryBufferRef buffer)
Returns true if the given buffer starts with the magic bytes that signal MLIR bytecode.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
AsmResourceEntryKind
This enum represents the different kinds of resource values.
Definition AsmState.h:280
LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config)
Read the operations defined within the given memory buffer, containing MLIR bytecode,...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< Block *, 1 > successors
Successors of this operation and their respective operands.
SmallVector< Value, 4 > operands
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
SmallVector< Type, 4 > types
Types of the results of this operation.