10 #include "../Encoding.h"
14 #include "llvm/ADT/CachedHashString.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/SmallString.h"
18 #define DEBUG_TYPE "mlir-bytecode-writer"
28 Impl(StringRef producer) : producer(producer) {}
38 :
impl(std::make_unique<
Impl>(producer)) {}
47 std::unique_ptr<AsmResourcePrinter> printer) {
48 impl->externalResourcePrinters.emplace_back(std::move(printer));
61 class EncodingEmitter {
63 EncodingEmitter() =
default;
64 EncodingEmitter(
const EncodingEmitter &) =
delete;
65 EncodingEmitter &operator=(
const EncodingEmitter &) =
delete;
68 void writeTo(raw_ostream &os)
const;
71 size_t size()
const {
return prevResultSize + currentResult.size(); }
78 void patchByte(uint64_t offset, uint8_t value) {
79 assert(offset < size() && offset >= prevResultSize &&
80 "cannot patch previously emitted data");
81 currentResult[offset - prevResultSize] = value;
88 appendResult(std::move(currentResult));
89 appendOwnedResult(data);
97 emitVarInt(alignment);
98 emitVarInt(data.size());
103 void emitOwnedBlobAndAlignment(
ArrayRef<char> data, uint32_t alignment) {
106 emitOwnedBlobAndAlignment(castedData, alignment);
110 void alignTo(
unsigned alignment) {
113 assert(llvm::isPowerOf2_32(alignment) &&
"expected valid alignment");
117 size_t curOffset = size();
118 size_t paddingSize = llvm::alignTo(curOffset, alignment) - curOffset;
119 while (paddingSize--)
123 requiredAlignment =
std::max(requiredAlignment, alignment);
130 template <
typename T>
131 void emitByte(T
byte) {
132 currentResult.push_back(
static_cast<uint8_t
>(
byte));
137 llvm::append_range(currentResult, bytes);
147 void emitVarInt(uint64_t value) {
150 if ((value >> 7) == 0)
151 return emitByte((value << 1) | 0x1);
152 emitMultiByteVarInt(value);
159 void emitSignedVarInt(uint64_t value) {
160 emitVarInt((value << 1) ^ (uint64_t)((int64_t)value >> 63));
165 void emitVarIntWithFlag(uint64_t value,
bool flag) {
166 emitVarInt((value << 1) | (flag ? 1 : 0));
173 void emitNulTerminatedString(StringRef str) {
179 void emitString(StringRef str) {
180 emitBytes({
reinterpret_cast<const uint8_t *
>(str.data()), str.size()});
192 uint64_t codeOffset = currentResult.size();
194 emitVarInt(emitter.size());
197 unsigned emitterAlign = emitter.requiredAlignment;
198 if (emitterAlign > 1) {
199 if (size() & (emitterAlign - 1)) {
200 emitVarInt(emitterAlign);
201 alignTo(emitterAlign);
205 currentResult[codeOffset] |= 0b10000000;
209 requiredAlignment =
std::max(requiredAlignment, emitterAlign);
215 appendResult(std::move(currentResult));
216 for (std::vector<uint8_t> &result : emitter.prevResultStorage)
217 prevResultStorage.push_back(std::move(result));
218 llvm::append_range(prevResultList, emitter.prevResultList);
219 prevResultSize += emitter.prevResultSize;
220 appendResult(std::move(emitter.currentResult));
228 LLVM_ATTRIBUTE_NOINLINE
void emitMultiByteVarInt(uint64_t value);
231 void appendResult(std::vector<uint8_t> &&result) {
234 prevResultStorage.emplace_back(std::move(result));
235 appendOwnedResult(prevResultStorage.back());
240 prevResultSize += result.size();
241 prevResultList.emplace_back(result);
249 std::vector<uint8_t> currentResult;
250 std::vector<ArrayRef<uint8_t>> prevResultList;
251 std::vector<std::vector<uint8_t>> prevResultStorage;
255 size_t prevResultSize = 0;
258 unsigned requiredAlignment = 1;
267 class StringSectionBuilder {
271 size_t insert(StringRef str) {
272 auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
273 return it.first->second;
277 void write(EncodingEmitter &emitter) {
278 emitter.emitVarInt(strings.size());
282 for (
const auto &it : llvm::reverse(strings))
283 emitter.emitVarInt(it.first.size() + 1);
285 for (
const auto &it : strings)
286 emitter.emitNulTerminatedString(it.first.val());
292 llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
299 StringSectionBuilder &stringSection)
300 : emitter(emitter), numberingState(numberingState),
301 stringSection(stringSection) {}
307 void writeAttribute(
Attribute attr)
override {
308 emitter.emitVarInt(numberingState.getNumber(attr));
310 void writeType(
Type type)
override {
311 emitter.emitVarInt(numberingState.getNumber(type));
315 emitter.emitVarInt(numberingState.getNumber(resource));
322 void writeVarInt(uint64_t value)
override { emitter.emitVarInt(value); }
324 void writeSignedVarInt(int64_t value)
override {
325 emitter.emitSignedVarInt(value);
328 void writeAPIntWithKnownWidth(
const APInt &value)
override {
329 size_t bitWidth = value.getBitWidth();
334 return emitter.emitByte(value.getLimitedValue());
338 return emitter.emitSignedVarInt(value.getLimitedValue());
343 unsigned numActiveWords = value.getActiveWords();
344 emitter.emitVarInt(numActiveWords);
346 const uint64_t *rawValueData = value.getRawData();
347 for (
unsigned i = 0; i < numActiveWords; ++i)
348 emitter.emitSignedVarInt(rawValueData[i]);
351 void writeAPFloatWithKnownSemantics(
const APFloat &value)
override {
352 writeAPIntWithKnownWidth(value.bitcastToAPInt());
355 void writeOwnedString(StringRef str)
override {
356 emitter.emitVarInt(stringSection.insert(str));
360 emitter.emitVarInt(blob.size());
362 reinterpret_cast<const uint8_t *
>(blob.data()), blob.size()));
366 EncodingEmitter &emitter;
368 StringSectionBuilder &stringSection;
374 class RawEmitterOstream :
public raw_ostream {
376 explicit RawEmitterOstream(EncodingEmitter &emitter) : emitter(emitter) {
381 void write_impl(
const char *ptr,
size_t size)
override {
382 emitter.emitBytes({
reinterpret_cast<const uint8_t *
>(ptr), size});
384 uint64_t current_pos()
const override {
return emitter.size(); }
387 EncodingEmitter &emitter;
391 void EncodingEmitter::writeTo(raw_ostream &os)
const {
392 for (
auto &prevResult : prevResultList)
393 os.write((
const char *)prevResult.data(), prevResult.size());
394 os.write((
const char *)currentResult.data(), currentResult.size());
397 void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
401 uint64_t it = value >> 7;
402 for (
size_t numBytes = 2; numBytes < 9; ++numBytes) {
403 if (LLVM_LIKELY(it >>= 7) == 0) {
404 uint64_t encodedValue = (value << 1) | 0x1;
405 encodedValue <<= (numBytes - 1);
406 emitBytes({
reinterpret_cast<uint8_t *
>(&encodedValue), numBytes});
414 emitBytes({
reinterpret_cast<uint8_t *
>(&value),
sizeof(value)});
422 class BytecodeWriter {
424 BytecodeWriter(
Operation *op) : numberingState(op) {}
427 void write(
Operation *rootOp, raw_ostream &os,
434 void writeDialectSection(EncodingEmitter &emitter);
439 void writeAttrTypeSection(EncodingEmitter &emitter);
444 void writeBlock(EncodingEmitter &emitter,
Block *block);
445 void writeOp(EncodingEmitter &emitter,
Operation *op);
446 void writeRegion(EncodingEmitter &emitter,
Region *region);
447 void writeIRSection(EncodingEmitter &emitter,
Operation *op);
452 void writeResourceSection(
Operation *op, EncodingEmitter &emitter,
458 void writeStringSection(EncodingEmitter &emitter);
464 StringSectionBuilder stringSection;
471 void BytecodeWriter::write(
Operation *rootOp, raw_ostream &os,
473 EncodingEmitter emitter;
477 emitter.emitString(
"ML\xefR");
483 emitter.emitNulTerminatedString(config.
producer);
486 writeDialectSection(emitter);
489 writeAttrTypeSection(emitter);
492 writeIRSection(emitter, rootOp);
495 writeResourceSection(rootOp, emitter, config);
498 writeStringSection(emitter);
511 template <
typename EntriesT,
typename EntryCallbackT>
513 EntryCallbackT &&callback) {
514 for (
auto it = entries.begin(), e = entries.end(); it != e;) {
515 auto groupStart = it++;
519 it = std::find_if(it, e, [&](
const auto &entry) {
520 return entry.dialect != currentDialect;
524 emitter.emitVarInt(currentDialect->
number);
525 emitter.emitVarInt(std::distance(groupStart, it));
528 for (
auto &entry : llvm::make_range(groupStart, it))
533 void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
534 EncodingEmitter dialectEmitter;
537 auto dialects = numberingState.getDialects();
538 dialectEmitter.emitVarInt(llvm::size(dialects));
541 size_t nameID = stringSection.insert(dialect.name);
544 EncodingEmitter versionEmitter;
545 if (dialect.interface) {
547 DialectWriter versionWriter(versionEmitter, numberingState,
549 dialect.interface->writeVersion(versionWriter);
555 size_t versionAvailable = versionEmitter.size() > 0;
556 dialectEmitter.emitVarIntWithFlag(nameID, versionAvailable);
557 if (versionAvailable)
559 std::move(versionEmitter));
564 dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect()));
574 void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
575 EncodingEmitter attrTypeEmitter;
576 EncodingEmitter offsetEmitter;
577 offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes()));
578 offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes()));
581 uint64_t prevOffset = 0;
582 auto emitAttrOrType = [&](
auto &entry) {
583 auto entryValue = entry.getValue();
586 bool hasCustomEncoding =
false;
589 DialectWriter dialectWriter(attrTypeEmitter, numberingState,
592 if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>,
Type>) {
595 !entryValue.template hasTrait<TypeTrait::IsMutable>() &&
596 succeeded(interface->writeType(entryValue, dialectWriter));
600 !entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
601 succeeded(interface->writeAttribute(entryValue, dialectWriter));
607 if (!hasCustomEncoding) {
608 RawEmitterOstream(attrTypeEmitter) << entryValue;
609 attrTypeEmitter.emitByte(0);
613 uint64_t curOffset = attrTypeEmitter.size();
614 offsetEmitter.emitVarIntWithFlag(curOffset - prevOffset, hasCustomEncoding);
615 prevOffset = curOffset;
626 std::move(offsetEmitter));
633 void BytecodeWriter::writeBlock(EncodingEmitter &emitter,
Block *block) {
635 bool hasArgs = !args.empty();
640 unsigned numOps = numberingState.getOperationCount(block);
641 emitter.emitVarIntWithFlag(numOps, hasArgs);
645 emitter.emitVarInt(args.size());
647 emitter.emitVarInt(numberingState.getNumber(arg.getType()));
648 emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
654 writeOp(emitter, &op);
657 void BytecodeWriter::writeOp(EncodingEmitter &emitter,
Operation *op) {
658 emitter.emitVarInt(numberingState.getNumber(op->
getName()));
663 uint64_t maskOffset = emitter.size();
664 uint8_t opEncodingMask = 0;
668 emitter.emitVarInt(numberingState.getNumber(op->
getLoc()));
672 if (!attrs.empty()) {
680 emitter.emitVarInt(numResults);
682 emitter.emitVarInt(numberingState.getNumber(type));
688 emitter.emitVarInt(numOperands);
690 emitter.emitVarInt(numberingState.getNumber(operand));
696 emitter.emitVarInt(numSuccessors);
698 emitter.emitVarInt(numberingState.getNumber(successor));
707 emitter.patchByte(maskOffset, opEncodingMask);
715 emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
718 writeRegion(emitter, &region);
722 void BytecodeWriter::writeRegion(EncodingEmitter &emitter,
Region *region) {
726 return emitter.emitVarInt( 0);
729 unsigned numBlocks, numValues;
730 std::tie(numBlocks, numValues) = numberingState.getBlockValueCount(region);
731 emitter.emitVarInt(numBlocks);
732 emitter.emitVarInt(numValues);
735 for (
Block &block : *region)
736 writeBlock(emitter, &block);
739 void BytecodeWriter::writeIRSection(EncodingEmitter &emitter,
Operation *op) {
740 EncodingEmitter irEmitter;
745 irEmitter.emitVarIntWithFlag( 1,
false);
748 writeOp(irEmitter, op);
763 ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection,
764 PostProcessFn postProcessFn)
765 : emitter(emitter), stringSection(stringSection),
766 postProcessFn(postProcessFn) {}
767 ~ResourceBuilder()
override =
default;
770 uint32_t dataAlignment)
final {
771 emitter.emitOwnedBlobAndAlignment(data, dataAlignment);
774 void buildBool(StringRef key,
bool data)
final {
775 emitter.emitByte(data);
778 void buildString(StringRef key, StringRef data)
final {
779 emitter.emitVarInt(stringSection.insert(data));
784 EncodingEmitter &emitter;
785 StringSectionBuilder &stringSection;
786 PostProcessFn postProcessFn;
790 void BytecodeWriter::writeResourceSection(
793 EncodingEmitter resourceEmitter;
794 EncodingEmitter resourceOffsetEmitter;
795 uint64_t prevOffset = 0;
802 uint64_t curOffset = resourceEmitter.size();
803 curResourceEntries.emplace_back(key, kind, curOffset - prevOffset);
804 prevOffset = curOffset;
808 auto emitResourceGroup = [&](uint64_t key) {
809 resourceOffsetEmitter.emitVarInt(key);
810 resourceOffsetEmitter.emitVarInt(curResourceEntries.size());
811 for (
auto [key, kind, size] : curResourceEntries) {
812 resourceOffsetEmitter.emitVarInt(stringSection.insert(key));
813 resourceOffsetEmitter.emitVarInt(size);
814 resourceOffsetEmitter.emitByte(kind);
819 ResourceBuilder entryBuilder(resourceEmitter, stringSection,
820 appendResourceOffset);
825 curResourceEntries.clear();
826 printer->buildResources(op, entryBuilder);
827 emitResourceGroup(stringSection.insert(printer->getName()));
832 if (!dialect.asmInterface)
834 curResourceEntries.clear();
835 dialect.asmInterface->buildResources(op, dialect.resources, entryBuilder);
840 for (
const auto &resource : dialect.resourceMap)
841 if (resource.second->isDeclaration)
845 if (!curResourceEntries.empty())
846 emitResourceGroup(dialect.number);
850 if (resourceOffsetEmitter.size() == 0)
854 std::move(resourceOffsetEmitter));
861 void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
862 EncodingEmitter stringEmitter;
863 stringSection.write(stringEmitter);
873 BytecodeWriter writer(op);
874 writer.write(op, os, config.
getImpl());
static void writeDialectGrouping(EncodingEmitter &emitter, EntriesT &&entries, EntryCallbackT &&callback)
Write the given entries in contiguous groups with the same parent dialect.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an opaque handle to a dialect resource entry.
This class is used to build resource entries for use by the printer.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
This class contains the configuration used for the bytecode writer.
BytecodeWriterConfig(StringRef producer="MLIR" LLVM_VERSION_STRING)
producer is an optional string that can be used to identify the producer of the bytecode when reading...
void attachFallbackResourcePrinter(FallbackAsmResourceMap &map)
Attach resource printers to the AsmState for the fallback resources in the given map.
const Impl & getImpl() const
Return an instance of the internal implementation.
void attachResourcePrinter(std::unique_ptr< AsmResourcePrinter > printer)
Attach the given resource printer to the writer configuration.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
A fallback map containing external resources not explicitly handled by another parser/printer.
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
unsigned getNumSuccessors()
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This class manages numbering IR entities in preparation of bytecode emission.
@ kAttrType
This section contains the attributes and types referenced within an IR module.
@ kAttrTypeOffset
This section contains the offsets for the attribute and types within the AttrType section.
@ kIR
This section contains the list of operations serialized into the bytecode, and their nested regions/o...
@ kResource
This section contains the resources of the bytecode.
@ kResourceOffset
This section contains the offsets of resources within the Resource section.
@ kDialect
This section contains the dialects referenced within an IR module.
@ kString
This section contains strings referenced within the bytecode.
@ kDialectVersions
This section contains the versions of each dialect.
@ kAlignmentByte
An arbitrary value used to fill alignment padding.
@ kVersion
The current bytecode version.
This header declares functions that assist transformations in the MemRef dialect.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
AsmResourceEntryKind
This enum represents the different kinds of resource values.
@ Blob
A blob of data with an accompanying alignment.
void writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config={})
Write the bytecode for the given operation to the provided output stream.
StringRef producer
The producer of the bytecode.
SmallVector< std::unique_ptr< AsmResourcePrinter > > externalResourcePrinters
A collection of non-dialect resource printers.
This class represents a numbering entry for a Dialect.
unsigned number
The number assigned to the dialect.