14 #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
15 #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Twine.h"
54 virtual FailureOr<const DialectVersion *>
71 template <
typename T,
typename CallbackFn>
78 for (uint64_t i = 0; i < size; ++i) {
81 if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
83 if (failed(callback(element)))
85 result.emplace_back(std::move(element));
87 FailureOr<T> element = callback();
90 result.emplace_back(std::move(*element));
106 template <
typename T>
110 template <
typename T>
115 if ((result = dyn_cast<T>(baseResult)))
117 return emitError() <<
"expected " << llvm::getTypeName<T>()
118 <<
", but got: " << baseResult;
120 template <
typename T>
127 if ((result = dyn_cast<T>(baseResult)))
129 return emitError() <<
"expected " << llvm::getTypeName<T>()
130 <<
", but got: " << baseResult;
135 template <
typename T>
139 template <
typename T>
144 if ((result = dyn_cast<T>(baseResult)))
146 return emitError() <<
"expected " << llvm::getTypeName<T>()
147 <<
", but got: " << baseResult;
151 template <
typename ResourceT>
156 if (
auto *result = dyn_cast<ResourceT>(&*handle))
157 return std::move(*result);
158 return emitError() <<
"provided resource handle differs from the "
159 "expected resource type";
192 template <
typename T>
194 static_assert(
sizeof(T) <
sizeof(uint64_t),
"expect integer < 64 bits");
195 static_assert(std::is_integral<T>::value,
"expects integer");
196 uint64_t nonZeroesCount;
197 bool useSparseEncoding;
200 if (nonZeroesCount == 0)
202 if (!useSparseEncoding) {
204 if (nonZeroesCount > array.size()) {
206 << nonZeroesCount <<
" but only " << array.size()
207 <<
" storage available.";
210 for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {
214 array[index] = value;
220 uint64_t indexBitSize;
223 constexpr uint64_t maxIndexBitSize = 8;
224 if (indexBitSize > maxIndexBitSize) {
225 emitError(
"reading sparse array with indexing above 8 bits: ")
229 for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {
231 uint64_t indexValuePair;
234 uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
235 uint64_t value = indexValuePair >> indexBitSize;
236 if (index >= array.size()) {
237 emitError(
"reading a sparse array found index ")
238 << index <<
" but only " << array.size() <<
" storage available.";
241 array[index] = value;
251 virtual FailureOr<APFloat>
287 template <
typename RangeT,
typename CallbackFn>
290 for (
auto &element : range)
297 template <
typename T>
304 template <
typename T>
339 template <
typename T>
341 static_assert(
sizeof(T) <
sizeof(uint64_t),
"expect integer < 64 bits");
342 static_assert(std::is_integral<T>::value,
"expects integer");
343 uint32_t size = array.size();
344 uint32_t nonZeroesCount = 0, lastIndex = 0;
345 for (uint32_t index : llvm::seq<uint32_t>(0, size)) {
353 if (lastIndex > 256 || nonZeroesCount > size / 2) {
356 for (
const T &elt : array)
363 if (nonZeroesCount == 0)
366 int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);
368 for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {
369 T value = array[index];
372 uint64_t indexValuePair = (value << indexBitSize) | (index);
408 virtual FailureOr<const DialectVersion *>
434 reader.
emitError() <<
"dialect " << getDialect()->getNamespace()
435 <<
" does not support reading attributes from bytecode";
443 reader.
emitError() <<
"dialect " << getDialect()->getNamespace()
444 <<
" does not support reading types from bytecode";
475 virtual std::unique_ptr<DialectVersion>
477 reader.
emitError(
"Dialect does not support versioning");
485 virtual LogicalResult
493 template <
typename T,
typename... Ts>
495 FailureOr<T> &value, Ts &&...params) {
499 if (
auto *result = dyn_cast<T>(&*handle)) {
500 value = std::move(*result);
508 template <
typename T,
typename... Ts>
511 if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
513 return T::get(std::forward<Ts>(params)...);
516 return T::get(context, std::forward<Ts>(params)...);
519 return T::Base::get(context, std::forward<Ts>(params)...);
This class represents an opaque handle to a dialect resource entry.
Attributes are known-constant values of operations.
virtual Type readType(DialectBytecodeReader &reader) const
Read a type belonging to this dialect from the given reader.
virtual LogicalResult upgradeFromVersion(Operation *topLevelOp, const DialectVersion &version) const
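Hook invoked after parsing completed, if a version directive was present and included an entry for the current dialect; it gives the dialect the opportunity to visit the IR and upgrade constructs emitted by the provided version.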
virtual Attribute readAttribute(DialectBytecodeReader &reader) const
Read an attribute belonging to this dialect from the given reader.
virtual std::unique_ptr< DialectVersion > readVersion(DialectBytecodeReader &reader) const
virtual LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const
Write the given attribute, which belongs to this dialect, to the given writer.
virtual LogicalResult writeType(Type type, DialectBytecodeWriter &writer) const
Write the given type, which belongs to this dialect, to the given writer.
virtual void writeVersion(DialectBytecodeWriter &writer) const
Write the version of this dialect to the given writer.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the bytecode reader.
virtual ~DialectBytecodeReader()=default
virtual LogicalResult readBlob(ArrayRef< char > &result)=0
Read a blob from the bytecode.
LogicalResult readAttributes(SmallVectorImpl< T > &attrs)
FailureOr< ResourceT > readResourceHandle()
Read a handle to a dialect resource.
virtual MLIRContext * getContext() const =0
Retrieve the context associated with the reader.
virtual FailureOr< APInt > readAPIntWithKnownWidth(unsigned bitWidth)=0
Read an APInt that is known to have been encoded with the given width.
LogicalResult readTypes(SmallVectorImpl< T > &types)
virtual LogicalResult readBool(bool &result)=0
Read a bool from the bytecode.
virtual LogicalResult readVarInt(uint64_t &result)=0
Read a variable width integer.
virtual LogicalResult readType(Type &result)=0
Read a reference to the given type.
virtual uint64_t getBytecodeVersion() const =0
Return the bytecode version being read.
LogicalResult readType(T &result)
LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag)
Parse a variable-length encoded integer whose low bit is used to encode an unrelated flag, i.e. `(integerValue << 1) | (flag ? 1 : 0)`.
LogicalResult readSignedVarInts(SmallVectorImpl< int64_t > &result)
LogicalResult readOptionalAttribute(T &result)
FailureOr< const DialectVersion * > getDialectVersion() const
virtual LogicalResult readOptionalAttribute(Attribute &attr)=0
Read an optional reference to the given attribute.
LogicalResult readAttribute(T &result)
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
LogicalResult readSparseArray(MutableArrayRef< T > array)
Read a "small" sparse array of integer elements of at most 32 bits, where index/value pairs can be compressed when the array is small.
virtual FailureOr< APFloat > readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics)=0
Read an APFloat that is known to have been encoded with the given semantics.
virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0
Retrieve the dialect version by name if available.
virtual LogicalResult readString(StringRef &result)=0
Read a string from the bytecode.
virtual LogicalResult readSignedVarInt(int64_t &result)=0
Read a signed variable width integer.
LogicalResult readList(SmallVectorImpl< T > &result, CallbackFn &&callback)
Read out a list of elements, invoking the provided callback for each element.
virtual LogicalResult readAttribute(Attribute &result)=0
Read a reference to the given attribute.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the bytecode writer.
virtual void writeOptionalAttribute(Attribute attr)=0
FailureOr< const DialectVersion * > getDialectVersion() const
virtual void writeVarInt(uint64_t value)=0
Write a variable width integer to the output stream.
void writeVarIntWithFlag(uint64_t value, bool flag)
Write a VarInt and a flag packed together.
void writeList(RangeT &&range, CallbackFn &&callback)
Write out a list of elements, invoking the provided callback for each element.
void writeSparseArray(ArrayRef< T > array)
Write out a "small" sparse array of integer elements of at most 32 bits, where index/value pairs can be compressed when the array is small.
virtual void writeType(Type type)=0
Write a reference to the given type.
virtual FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const =0
Retrieve the dialect version by name if available.
virtual void writeAPIntWithKnownWidth(const APInt &value)=0
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
virtual void writeOwnedBlob(ArrayRef< char > blob)=0
Write a blob to the bytecode; the blob is owned by the caller and is guaranteed to remain alive until the end of the bytecode process.
virtual void writeAttribute(Attribute attr)=0
Write a reference to the given attribute.
virtual ~DialectBytecodeWriter()=default
void writeAttributes(ArrayRef< T > attrs)
virtual void writeSignedVarInt(int64_t value)=0
Write a signed variable width integer to the output stream.
virtual void writeResourceHandle(const AsmDialectResourceHandle &resource)=0
Write the given handle to a dialect resource.
virtual void writeAPFloatWithKnownSemantics(const APFloat &value)=0
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
void writeSignedVarInts(ArrayRef< int64_t > value)
virtual void writeOwnedBool(bool value)=0
Write a bool to the output stream.
virtual int64_t getBytecodeVersion() const =0
Return the bytecode version being emitted for.
virtual void writeOwnedString(StringRef str)=0
Write a string to the bytecode; the string is owned by the caller and is guaranteed to remain alive until the end of the bytecode process.
void writeTypes(ArrayRef< T > types)
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
virtual ~DialectVersion()=default
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
The base class used for all derived interface types.
decltype(T::get(std::declval< Ts >()...)) has_get_method
Include the generated interface declarations.
static LogicalResult readResourceHandle(DialectBytecodeReader &reader, FailureOr< T > &value, Ts &&...params)
Helper for resource handle reading that returns LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.