14 #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
15 #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Twine.h"
72 template <
typename T,
typename CallbackFn>
79 for (uint64_t i = 0; i < size; ++i) {
82 if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
84 if (
failed(callback(element)))
86 result.emplace_back(std::move(element));
91 result.emplace_back(std::move(*element));
107 template <
typename T>
111 template <
typename T>
116 if ((result = dyn_cast<T>(baseResult)))
118 return emitError() <<
"expected " << llvm::getTypeName<T>()
119 <<
", but got: " << baseResult;
121 template <
typename T>
128 if ((result = dyn_cast<T>(baseResult)))
130 return emitError() <<
"expected " << llvm::getTypeName<T>()
131 <<
", but got: " << baseResult;
136 template <
typename T>
140 template <
typename T>
145 if ((result = dyn_cast<T>(baseResult)))
147 return emitError() <<
"expected " << llvm::getTypeName<T>()
148 <<
", but got: " << baseResult;
152 template <
typename ResourceT>
157 if (
auto *result = dyn_cast<ResourceT>(&*handle))
158 return std::move(*result);
159 return emitError() <<
"provided resource handle differs from the "
160 "expected resource type";
193 template <
typename T>
195 static_assert(
sizeof(T) <
sizeof(uint64_t),
"expect integer < 64 bits");
196 static_assert(std::is_integral<T>::value,
"expects integer");
197 uint64_t nonZeroesCount;
198 bool useSparseEncoding;
201 if (nonZeroesCount == 0)
203 if (!useSparseEncoding) {
205 if (nonZeroesCount > array.size()) {
207 << nonZeroesCount <<
" but only " << array.size()
208 <<
" storage available.";
211 for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {
215 array[index] = value;
221 uint64_t indexBitSize;
224 constexpr uint64_t maxIndexBitSize = 8;
225 if (indexBitSize > maxIndexBitSize) {
226 emitError(
"reading sparse array with indexing above 8 bits: ")
230 for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {
232 uint64_t indexValuePair;
235 uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
236 uint64_t value = indexValuePair >> indexBitSize;
237 if (index >= array.size()) {
238 emitError(
"reading a sparse array found index ")
239 << index <<
" but only " << array.size() <<
" storage available.";
242 array[index] = value;
288 template <
typename RangeT,
typename CallbackFn>
291 for (
auto &element : range)
298 template <
typename T>
305 template <
typename T>
340 template <
typename T>
342 static_assert(
sizeof(T) <
sizeof(uint64_t),
"expect integer < 64 bits");
343 static_assert(std::is_integral<T>::value,
"expects integer");
344 uint32_t size = array.size();
345 uint32_t nonZeroesCount = 0, lastIndex = 0;
346 for (uint32_t index : llvm::seq<uint32_t>(0, size)) {
354 if (lastIndex > 256 || nonZeroesCount > size / 2) {
357 for (
const T &elt : array)
364 if (nonZeroesCount == 0)
367 int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);
369 for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {
370 T value = array[index];
373 uint64_t indexValuePair = (value << indexBitSize) | (index);
435 reader.
emitError() <<
"dialect " << getDialect()->getNamespace()
436 <<
" does not support reading attributes from bytecode";
444 reader.
emitError() <<
"dialect " << getDialect()->getNamespace()
445 <<
" does not support reading types from bytecode";
476 virtual std::unique_ptr<DialectVersion>
478 reader.
emitError(
"Dialect does not support versioning");
494 template <
typename T,
typename... Ts>
500 if (
auto *result = dyn_cast<T>(&*handle)) {
501 value = std::move(*result);
509 template <
typename T,
typename... Ts>
512 if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
514 return T::get(std::forward<Ts>(params)...);
517 return T::get(context, std::forward<Ts>(params)...);
520 return T::Base::get(context, std::forward<Ts>(params)...);
This class represents an opaque handle to a dialect resource entry.
Attributes are known-constant values of operations.
virtual Type readType(DialectBytecodeReader &reader) const
Read a type belonging to this dialect from the given reader.
virtual LogicalResult upgradeFromVersion(Operation *topLevelOp, const DialectVersion &version) const
Hook invoked after parsing completed, if a version directive was present and included an entry for th...
virtual Attribute readAttribute(DialectBytecodeReader &reader) const
Read an attribute belonging to this dialect from the given reader.
virtual std::unique_ptr< DialectVersion > readVersion(DialectBytecodeReader &reader) const
virtual LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const
Write the given attribute, which belongs to this dialect, to the given writer.
virtual LogicalResult writeType(Type type, DialectBytecodeWriter &writer) const
Write the given type, which belongs to this dialect, to the given writer.
virtual void writeVersion(DialectBytecodeWriter &writer) const
Write the version of this dialect to the given writer.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual ~DialectBytecodeReader()=default
virtual LogicalResult readBlob(ArrayRef< char > &result)=0
Read a blob from the bytecode.
LogicalResult readAttributes(SmallVectorImpl< T > &attrs)
FailureOr< ResourceT > readResourceHandle()
Read a handle to a dialect resource.
virtual MLIRContext * getContext() const =0
Retrieve the context associated to the reader.
virtual FailureOr< APInt > readAPIntWithKnownWidth(unsigned bitWidth)=0
Read an APInt that is known to have been encoded with the given width.
LogicalResult readTypes(SmallVectorImpl< T > &types)
virtual LogicalResult readBool(bool &result)=0
Read a bool from the bytecode.
virtual LogicalResult readVarInt(uint64_t &result)=0
Read a variable width integer.
virtual LogicalResult readType(Type &result)=0
Read a reference to the given type.
virtual uint64_t getBytecodeVersion() const =0
Return the bytecode version being read.
LogicalResult readType(T &result)
LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag)
Parse a variable length encoded integer whose low bit is used to encode an unrelated flag,...
LogicalResult readSignedVarInts(SmallVectorImpl< int64_t > &result)
LogicalResult readOptionalAttribute(T &result)
FailureOr< const DialectVersion * > getDialectVersion() const
virtual LogicalResult readOptionalAttribute(Attribute &attr)=0
Read an optional reference to the given attribute.
LogicalResult readAttribute(T &result)
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
LogicalResult readSparseArray(MutableArrayRef< T > array)
Read a "small" sparse array of integer <= 32 bits elements, where index/value pairs can be compressed...
virtual FailureOr< APFloat > readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics)=0
Read an APFloat that is known to have been encoded with the given semantics.
virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0
Retrieve the dialect version by name if available.
virtual LogicalResult readString(StringRef &result)=0
Read a string from the bytecode.
virtual LogicalResult readSignedVarInt(int64_t &result)=0
Read a signed variable width integer.
LogicalResult readList(SmallVectorImpl< T > &result, CallbackFn &&callback)
Read out a list of elements, invoking the provided callback for each element.
virtual LogicalResult readAttribute(Attribute &result)=0
Read a reference to the given attribute.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
virtual void writeOptionalAttribute(Attribute attr)=0
FailureOr< const DialectVersion * > getDialectVersion() const
virtual void writeVarInt(uint64_t value)=0
Write a variable width integer to the output stream.
void writeVarIntWithFlag(uint64_t value, bool flag)
Write a VarInt and a flag packed together.
void writeList(RangeT &&range, CallbackFn &&callback)
Write out a list of elements, invoking the provided callback for each element.
void writeSparseArray(ArrayRef< T > array)
Write out a "small" sparse array of integer <= 32 bits elements, where index/value pairs can be compr...
virtual void writeType(Type type)=0
Write a reference to the given type.
virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0
Retrieve the dialect version by name if available.
virtual void writeAPIntWithKnownWidth(const APInt &value)=0
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
virtual void writeOwnedBlob(ArrayRef< char > blob)=0
Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the en...
virtual void writeAttribute(Attribute attr)=0
Write a reference to the given attribute.
virtual ~DialectBytecodeWriter()=default
void writeAttributes(ArrayRef< T > attrs)
virtual void writeSignedVarInt(int64_t value)=0
Write a signed variable width integer to the output stream.
virtual void writeResourceHandle(const AsmDialectResourceHandle &resource)=0
Write the given handle to a dialect resource.
virtual void writeAPFloatWithKnownSemantics(const APFloat &value)=0
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
void writeSignedVarInts(ArrayRef< int64_t > value)
virtual void writeOwnedBool(bool value)=0
Write a bool to the output stream.
virtual int64_t getBytecodeVersion() const =0
Return the bytecode version being emitted for.
virtual void writeOwnedString(StringRef str)=0
Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the ...
void writeTypes(ArrayRef< T > types)
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
virtual ~DialectVersion()=default
This class provides support for representing a failure result, or a valid value of type T.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
The base class used for all derived interface types.
decltype(T::get(std::declval< Ts >()...)) has_get_method
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.