56 return state.getDesiredBytecodeVersion();
59 FailureOr<const DialectVersion *>
64 return dialectEntry->getValue().get();
88 auto sortByDialect = [](
unsigned dialectToOrderFirst,
const auto &
lhs,
90 if (
lhs->dialect->number == dialectToOrderFirst)
91 return rhs->dialect->number != dialectToOrderFirst;
92 if (
rhs->dialect->number == dialectToOrderFirst)
94 return lhs->dialect->number <
rhs->dialect->number;
97 unsigned dialectToOrderFirst = 0;
98 size_t elementsInByteGroup = 0;
99 auto iterRange = range;
100 for (
unsigned i = 1; i < 9; ++i) {
104 elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
108 auto byteSubRange = iterRange.take_front(elementsInByteGroup);
109 iterRange = iterRange.drop_front(byteSubRange.size());
112 llvm::stable_sort(byteSubRange, [&](
const auto &
lhs,
const auto &
rhs) {
113 return sortByDialect(dialectToOrderFirst,
lhs,
rhs);
119 dialectToOrderFirst = byteSubRange.back()->dialect->number;
122 if (iterRange.empty())
127 for (
auto [idx, value] : llvm::enumerate(range))
134 computeGlobalNumberingState(op);
145 auto addOpRegionsToNumber = [&](
Operation *op) {
153 for (
Region &region : regions)
154 numberContext.emplace_back(&region, opFirstValueID);
156 addOpRegionsToNumber(op);
159 while (!numberContext.empty()) {
161 std::tie(region, nextValueID) = numberContext.pop_back_val();
166 addOpRegionsToNumber(&op);
173 for (
auto [idx, dialect] : llvm::enumerate(dialects))
174 dialect.second->number = idx;
181 auto sortByRefCountFn = [](
const auto &
lhs,
const auto &
rhs) {
182 return lhs->refCount >
rhs->refCount;
184 llvm::stable_sort(orderedAttrs, sortByRefCountFn);
185 llvm::stable_sort(orderedOpNames, sortByRefCountFn);
186 llvm::stable_sort(orderedTypes, sortByRefCountFn);
198 finalizeDialectResourceNumberings(op);
201void IRNumberingState::computeGlobalNumberingState(
Operation *rootOp) {
229 bool hasUnresolvedIsolation;
234 unsigned operationID = 0;
272 if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
276 if (operandRegion == parentRegion)
282 auto it = std::find_if(
283 opStack.rbegin(), opStack.rend(), [=](
const StackState &it) {
286 return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
288 assert(it != opStack.rend() &&
"expected to find the container");
289 for (
auto &state : llvm::make_range(opStack.rbegin(), it)) {
292 state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
293 state.numbering->isIsolatedFromAbove =
false;
300 new (opAllocator.Allocate()) OperationNumbering(operationID++);
301 if (op->
hasTrait<OpTrait::IsIsolatedFromAbove>())
303 operations.try_emplace(op, numbering);
305 opStack.emplace_back(StackState{
311void IRNumberingState::number(Attribute attr) {
312 auto it = attrs.try_emplace(attr);
314 ++it.first->second->refCount;
317 auto *numbering =
new (attrAllocator.Allocate()) AttributeNumbering(attr);
318 it.first->second = numbering;
319 orderedAttrs.push_back(numbering);
325 if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
326 numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
329 numbering->dialect = &numberDialect(&attr.
getDialect());
336 for (
const auto &callback : config.getAttributeWriterCallbacks()) {
340 std::optional<StringRef> groupNameOverride;
341 if (succeeded(callback->write(attr, groupNameOverride, writer))) {
342 if (groupNameOverride.has_value())
343 numbering->dialect = &numberDialect(*groupNameOverride);
348 if (
const auto *interface = numbering->dialect->interface) {
350 if (succeeded(interface->writeAttribute(attr, writer)))
359 llvm::raw_null_ostream dummyOS;
360 attr.
print(dummyOS, tempState);
363 for (
const auto &it : tempState.getDialectResources())
364 number(it.getFirst(), it.getSecond().getArrayRef());
367void IRNumberingState::number(
Block &block) {
370 valueIDs.try_emplace(arg, nextValueID++);
371 number(arg.getLoc());
372 number(arg.getType());
376 unsigned &numOps = blockOperationCounts[&block];
377 for (Operation &op : block) {
383auto IRNumberingState::numberDialect(Dialect *dialect) ->
DialectNumbering & {
386 numbering = &numberDialect(dialect->getNamespace());
387 numbering->
interface = dyn_cast<BytecodeDialectInterface>(dialect);
388 numbering->
asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
393auto IRNumberingState::numberDialect(StringRef dialect) ->
DialectNumbering & {
396 numbering =
new (dialectAllocator.Allocate())
402void IRNumberingState::number(Region &region) {
405 size_t firstValueID = nextValueID;
408 size_t blockCount = 0;
409 for (
auto it : llvm::enumerate(region)) {
410 blockIDs.try_emplace(&it.value(), it.index());
416 regionBlockValueCounts.try_emplace(&region, blockCount,
417 nextValueID - firstValueID);
420void IRNumberingState::number(Operation &op) {
425 valueIDs.try_emplace(
result, nextValueID++);
432 DictionaryAttr dictAttr;
438 if (!dictAttr.empty())
443 if (config.getDesiredBytecodeVersion() >=
448 auto iface = cast<BytecodeOpInterface>(op);
450 iface.writeProperties(writer);
461void IRNumberingState::number(OperationName opName) {
462 OpNameNumbering *&numbering = opNames[opName];
467 DialectNumbering *dialectNumber =
nullptr;
469 dialectNumber = &numberDialect(dialect);
474 new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
475 orderedOpNames.push_back(numbering);
478void IRNumberingState::number(Type type) {
479 auto it = types.try_emplace(type);
481 ++it.first->second->refCount;
484 auto *numbering =
new (typeAllocator.Allocate()) TypeNumbering(type);
485 it.first->second = numbering;
486 orderedTypes.push_back(numbering);
492 if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
493 numbering->
dialect = &numberDialect(opaqueType.getDialectNamespace());
503 for (
const auto &callback : config.getTypeWriterCallbacks()) {
507 std::optional<StringRef> groupNameOverride;
508 if (succeeded(callback->write(type, groupNameOverride, writer))) {
509 if (groupNameOverride.has_value())
510 numbering->
dialect = &numberDialect(*groupNameOverride);
519 if (succeeded(interface->writeType(type, writer)))
528 llvm::raw_null_ostream dummyOS;
529 type.
print(dummyOS, tempState);
532 for (
const auto &it : tempState.getDialectResources())
533 number(it.getFirst(), it.getSecond().getArrayRef());
536void IRNumberingState::number(Dialect *dialect,
537 ArrayRef<AsmDialectResourceHandle> resources) {
538 DialectNumbering &dialectNumber = numberDialect(dialect);
541 "expected dialect owning a resource to implement OpAsmDialectInterface");
543 for (
const auto &resource : resources) {
545 if (!dialectNumber.
resources.insert(resource))
549 new (resourceAllocator.Allocate()) DialectResourceNumbering(
551 dialectNumber.
resourceMap.insert({numbering->key, numbering});
552 dialectResources.try_emplace(resource, numbering);
557 return config.getDesiredBytecodeVersion();
563 NumberingResourceBuilder(
DialectNumbering *dialect,
unsigned &nextResourceID)
564 : dialect(dialect), nextResourceID(nextResourceID) {}
565 ~NumberingResourceBuilder()
override =
default;
570 void buildBool(StringRef key,
bool)
final { numberEntry(key); }
571 void buildString(StringRef key, StringRef)
final {
577 void numberEntry(StringRef key) {
580 auto *it = dialect->resourceMap.find(key);
581 if (it != dialect->resourceMap.end()) {
582 it->second->number = nextResourceID++;
583 it->second->isDeclaration =
false;
587 DialectNumbering *dialect;
588 unsigned &nextResourceID;
592void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
593 unsigned nextResourceID = 0;
595 if (!dialect.asmInterface)
597 NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
598 dialect.asmInterface->buildResources(rootOp, dialect.resources,
605 for (
const auto &it : dialect.resourceMap)
606 if (it.second->isDeclaration)
607 it.second->number = nextResourceID++;
static void groupByDialectPerByte(T range)
Group and sort the elements of the given range by their parent dialect.
This class represents an opaque handle to a dialect resource entry.
Dialect * getDialect() const
Return the dialect that owns the resource.
This class is used to build resource entries for use by the printer.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
MLIRContext * getContext() const
Return the context this attribute belongs to.
bool hasTrait()
Returns true if the attribute was registered with a particular trait.
BlockArgListType getArguments()
This class contains the configuration used for the bytecode writer.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the bytecode writer.
Dialect * getDialect() const
Return the dialect this operation is registered to if the dialect is loaded in the context, or nullptr if the dialect isn't loaded.
StringRef getDialectNamespace() const
Return the name of the dialect this operation is registered to.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OpTrait::IsIsolatedFromAbove>().
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
Region * getParentRegion()
Returns the region to which the instruction belongs.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
void print(raw_ostream &os) const
Print the current type.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool hasTrait()
Returns true if the type was registered with a particular trait.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
A utility class to encode the current walk stage for "generic" walkers.
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
IRNumberingState(Operation *op, const BytecodeWriterConfig &config)
int64_t getDesiredBytecodeVersion() const
Get the set desired bytecode version to emit.
bool isIsolatedFromAbove(Operation *op)
Return if the given operation is isolated from above.
auto getDialects()
Return the numbered dialects.
detail::StorageUserTrait::IsMutable< ConcreteType > IsMutable
This trait is used to determine if an attribute is mutable or not.
detail::StorageUserTrait::IsMutable< ConcreteType > IsMutable
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attributes.
Include the generated interface declarations.
NumberingDialectWriter(IRNumberingState &state, llvm::StringMap< std::unique_ptr< DialectVersion > > &dialectVersionMap)
llvm::StringMap< std::unique_ptr< DialectVersion > > & dialectVersionMap
A map containing dialect version information for each dialect to emit.
void writeType(Type type) override
Write a reference to the given type.
void writeVarInt(uint64_t) override
Stubbed out methods that are not used for numbering.
FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const override
Retrieve the dialect version by name if available.
void writeOptionalAttribute(Attribute attr) override
void writeOwnedString(StringRef) override
Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the end of the bytecode process.
IRNumberingState & state
The parent numbering state that is populated by this writer.
void writeUnownedBlob(ArrayRef< char > blob) override
Write a blob to the bytecode, which is not owned by the caller.
void writeAttribute(Attribute attr) override
Write a reference to the given attribute.
void writeResourceHandle(const AsmDialectResourceHandle &resource) override
Write the given handle to a dialect resource.
void writeOwnedBool(bool value) override
Write a bool to the output stream.
void writeAPIntWithKnownWidth(const APInt &value) override
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
void writeSignedVarInt(int64_t value) override
Write a signed variable width integer to the output stream.
int64_t getBytecodeVersion() const override
Return the bytecode version being emitted for.
void writeOwnedBlob(ArrayRef< char > blob) override
Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the end of the bytecode process.
void writeAPFloatWithKnownSemantics(const APFloat &value) override
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
This class represents a numbering entry for a Dialect.
const BytecodeDialectInterface * interface
The bytecode dialect interface of the dialect if defined.
llvm::MapVector< StringRef, DialectResourceNumbering * > resourceMap
A mapping from resource key to the corresponding resource numbering entry.
SetVector< AsmDialectResourceHandle > resources
The referenced resources of this dialect.
const OpAsmDialectInterface * asmInterface
The asm dialect interface of the dialect if defined.
This class represents the numbering entry of an operation.
std::optional< bool > isIsolatedFromAbove
A flag indicating if this operation's regions are isolated.