55 return state.getDesiredBytecodeVersion();
58 FailureOr<const DialectVersion *>
63 return dialectEntry->getValue().get();
87 auto sortByDialect = [](
unsigned dialectToOrderFirst,
const auto &
lhs,
89 if (
lhs->dialect->number == dialectToOrderFirst)
90 return rhs->dialect->number != dialectToOrderFirst;
91 if (
rhs->dialect->number == dialectToOrderFirst)
93 return lhs->dialect->number <
rhs->dialect->number;
96 unsigned dialectToOrderFirst = 0;
97 size_t elementsInByteGroup = 0;
98 auto iterRange = range;
99 for (
unsigned i = 1; i < 9; ++i) {
103 elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
107 auto byteSubRange = iterRange.take_front(elementsInByteGroup);
108 iterRange = iterRange.drop_front(byteSubRange.size());
111 llvm::stable_sort(byteSubRange, [&](
const auto &
lhs,
const auto &
rhs) {
112 return sortByDialect(dialectToOrderFirst,
lhs,
rhs);
118 dialectToOrderFirst = byteSubRange.back()->dialect->number;
121 if (iterRange.empty())
126 for (
auto [idx, value] : llvm::enumerate(range))
133 computeGlobalNumberingState(op);
144 auto addOpRegionsToNumber = [&](
Operation *op) {
152 for (
Region &region : regions)
153 numberContext.emplace_back(&region, opFirstValueID);
155 addOpRegionsToNumber(op);
158 while (!numberContext.empty()) {
160 std::tie(region, nextValueID) = numberContext.pop_back_val();
165 addOpRegionsToNumber(&op);
172 for (
auto [idx, dialect] : llvm::enumerate(dialects))
173 dialect.second->number = idx;
180 auto sortByRefCountFn = [](
const auto &
lhs,
const auto &
rhs) {
181 return lhs->refCount >
rhs->refCount;
183 llvm::stable_sort(orderedAttrs, sortByRefCountFn);
184 llvm::stable_sort(orderedOpNames, sortByRefCountFn);
185 llvm::stable_sort(orderedTypes, sortByRefCountFn);
197 finalizeDialectResourceNumberings(op);
200void IRNumberingState::computeGlobalNumberingState(
Operation *rootOp) {
228 bool hasUnresolvedIsolation;
233 unsigned operationID = 0;
271 if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
275 if (operandRegion == parentRegion)
281 auto it = std::find_if(
282 opStack.rbegin(), opStack.rend(), [=](
const StackState &it) {
285 return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
287 assert(it != opStack.rend() &&
"expected to find the container");
288 for (
auto &state : llvm::make_range(opStack.rbegin(), it)) {
291 state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
292 state.numbering->isIsolatedFromAbove =
false;
299 new (opAllocator.Allocate()) OperationNumbering(operationID++);
300 if (op->
hasTrait<OpTrait::IsIsolatedFromAbove>())
302 operations.try_emplace(op, numbering);
304 opStack.emplace_back(StackState{
310void IRNumberingState::number(Attribute attr) {
311 auto it = attrs.try_emplace(attr);
313 ++it.first->second->refCount;
316 auto *numbering =
new (attrAllocator.Allocate()) AttributeNumbering(attr);
317 it.first->second = numbering;
318 orderedAttrs.push_back(numbering);
324 if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
325 numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
328 numbering->dialect = &numberDialect(&attr.
getDialect());
335 for (
const auto &callback : config.getAttributeWriterCallbacks()) {
339 std::optional<StringRef> groupNameOverride;
340 if (succeeded(callback->write(attr, groupNameOverride, writer))) {
341 if (groupNameOverride.has_value())
342 numbering->dialect = &numberDialect(*groupNameOverride);
347 if (
const auto *interface = numbering->dialect->interface) {
349 if (succeeded(interface->writeAttribute(attr, writer)))
358 llvm::raw_null_ostream dummyOS;
359 attr.
print(dummyOS, tempState);
362 for (
const auto &it : tempState.getDialectResources())
363 number(it.getFirst(), it.getSecond().getArrayRef());
366void IRNumberingState::number(
Block &block) {
369 valueIDs.try_emplace(arg, nextValueID++);
370 number(arg.getLoc());
371 number(arg.getType());
375 unsigned &numOps = blockOperationCounts[&block];
376 for (Operation &op : block) {
382auto IRNumberingState::numberDialect(Dialect *dialect) ->
DialectNumbering & {
385 numbering = &numberDialect(dialect->getNamespace());
386 numbering->
interface = dyn_cast<BytecodeDialectInterface>(dialect);
387 numbering->
asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
392auto IRNumberingState::numberDialect(StringRef dialect) ->
DialectNumbering & {
395 numbering =
new (dialectAllocator.Allocate())
401void IRNumberingState::number(Region &region) {
404 size_t firstValueID = nextValueID;
407 size_t blockCount = 0;
408 for (
auto it : llvm::enumerate(region)) {
409 blockIDs.try_emplace(&it.value(), it.index());
415 regionBlockValueCounts.try_emplace(&region, blockCount,
416 nextValueID - firstValueID);
419void IRNumberingState::number(Operation &op) {
424 valueIDs.try_emplace(
result, nextValueID++);
431 DictionaryAttr dictAttr;
437 if (!dictAttr.empty())
442 if (config.getDesiredBytecodeVersion() >=
447 auto iface = cast<BytecodeOpInterface>(op);
449 iface.writeProperties(writer);
460void IRNumberingState::number(OperationName opName) {
461 OpNameNumbering *&numbering = opNames[opName];
466 DialectNumbering *dialectNumber =
nullptr;
468 dialectNumber = &numberDialect(dialect);
473 new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
474 orderedOpNames.push_back(numbering);
477void IRNumberingState::number(Type type) {
478 auto it = types.try_emplace(type);
480 ++it.first->second->refCount;
483 auto *numbering =
new (typeAllocator.Allocate()) TypeNumbering(type);
484 it.first->second = numbering;
485 orderedTypes.push_back(numbering);
491 if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
492 numbering->
dialect = &numberDialect(opaqueType.getDialectNamespace());
502 for (
const auto &callback : config.getTypeWriterCallbacks()) {
506 std::optional<StringRef> groupNameOverride;
507 if (succeeded(callback->write(type, groupNameOverride, writer))) {
508 if (groupNameOverride.has_value())
509 numbering->
dialect = &numberDialect(*groupNameOverride);
518 if (succeeded(interface->writeType(type, writer)))
527 llvm::raw_null_ostream dummyOS;
528 type.
print(dummyOS, tempState);
531 for (
const auto &it : tempState.getDialectResources())
532 number(it.getFirst(), it.getSecond().getArrayRef());
535void IRNumberingState::number(Dialect *dialect,
536 ArrayRef<AsmDialectResourceHandle> resources) {
537 DialectNumbering &dialectNumber = numberDialect(dialect);
540 "expected dialect owning a resource to implement OpAsmDialectInterface");
542 for (
const auto &resource : resources) {
544 if (!dialectNumber.
resources.insert(resource))
548 new (resourceAllocator.Allocate()) DialectResourceNumbering(
550 dialectNumber.
resourceMap.insert({numbering->key, numbering});
551 dialectResources.try_emplace(resource, numbering);
556 return config.getDesiredBytecodeVersion();
562 NumberingResourceBuilder(
DialectNumbering *dialect,
unsigned &nextResourceID)
563 : dialect(dialect), nextResourceID(nextResourceID) {}
564 ~NumberingResourceBuilder()
override =
default;
569 void buildBool(StringRef key,
bool)
final { numberEntry(key); }
570 void buildString(StringRef key, StringRef)
final {
576 void numberEntry(StringRef key) {
579 auto *it = dialect->resourceMap.find(key);
580 if (it != dialect->resourceMap.end()) {
581 it->second->number = nextResourceID++;
582 it->second->isDeclaration =
false;
586 DialectNumbering *dialect;
587 unsigned &nextResourceID;
591void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
592 unsigned nextResourceID = 0;
594 if (!dialect.asmInterface)
596 NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
597 dialect.asmInterface->buildResources(rootOp, dialect.resources,
604 for (
const auto &it : dialect.resourceMap)
605 if (it.second->isDeclaration)
606 it.second->number = nextResourceID++;
static void groupByDialectPerByte(T range)
Group and sort the elements of the given range by their parent dialect.
This class represents an opaque handle to a dialect resource entry.
Dialect * getDialect() const
Return the dialect that owns the resource.
This class is used to build resource entries for use by the printer.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
MLIRContext * getContext() const
Return the context this attribute belongs to.
bool hasTrait()
Returns true if the type was registered with a particular trait.
BlockArgListType getArguments()
This class contains the configuration used for the bytecode writer.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the bytecode writer.
virtual std::string getResourceKey(const AsmDialectResourceHandle &handle) const
Return a key to use for the given resource.
Dialect * getDialect() const
Return the dialect this operation is registered to if the dialect is loaded in the context, or nullptr if the dialect isn't loaded.
StringRef getDialectNamespace() const
Return the name of the dialect this operation is registered to.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
bool hasTrait()
Returns true if the operation was registered with a particular trait.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
Region * getParentRegion()
Returns the region to which the instruction belongs.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
void print(raw_ostream &os) const
Print the current type.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool hasTrait()
Returns true if the type was registered with a particular trait.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
A utility class to encode the current walk stage for "generic" walkers.
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
IRNumberingState(Operation *op, const BytecodeWriterConfig &config)
int64_t getDesiredBytecodeVersion() const
Get the set desired bytecode version to emit.
bool isIsolatedFromAbove(Operation *op)
Return if the given operation is isolated from above.
auto getDialects()
Return the numbered dialects.
detail::StorageUserTrait::IsMutable< ConcreteType > IsMutable
This trait is used to determine if an attribute is mutable or not.
detail::StorageUserTrait::IsMutable< ConcreteType > IsMutable
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attributes.
Include the generated interface declarations.
NumberingDialectWriter(IRNumberingState &state, llvm::StringMap< std::unique_ptr< DialectVersion > > &dialectVersionMap)
llvm::StringMap< std::unique_ptr< DialectVersion > > & dialectVersionMap
A map containing dialect version information for each dialect to emit.
void writeType(Type type) override
Write a reference to the given type.
void writeVarInt(uint64_t) override
Stubbed out methods that are not used for numbering.
FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const override
Retrieve the dialect version by name if available.
void writeOptionalAttribute(Attribute attr) override
void writeOwnedString(StringRef) override
Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the end of the bytecode process.
IRNumberingState & state
The parent numbering state that is populated by this writer.
void writeAttribute(Attribute attr) override
Write a reference to the given attribute.
void writeResourceHandle(const AsmDialectResourceHandle &resource) override
Write the given handle to a dialect resource.
void writeOwnedBool(bool value) override
Write a bool to the output stream.
void writeAPIntWithKnownWidth(const APInt &value) override
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
void writeSignedVarInt(int64_t value) override
Write a signed variable width integer to the output stream.
int64_t getBytecodeVersion() const override
Return the bytecode version being emitted for.
void writeOwnedBlob(ArrayRef< char > blob) override
Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the end of the bytecode process.
void writeAPFloatWithKnownSemantics(const APFloat &value) override
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
This class represents a numbering entry for a Dialect.
const BytecodeDialectInterface * interface
The bytecode dialect interface of the dialect if defined.
llvm::MapVector< StringRef, DialectResourceNumbering * > resourceMap
A mapping from resource key to the corresponding resource numbering entry.
SetVector< AsmDialectResourceHandle > resources
The referenced resources of this dialect.
const OpAsmDialectInterface * asmInterface
The asm dialect interface of the dialect if defined.
This class represents the numbering entry of an operation.
std::optional< bool > isIsolatedFromAbove
A flag indicating if this operation's regions are isolated.