62 return dialectEntry->getValue().get();
86 auto sortByDialect = [](
unsigned dialectToOrderFirst,
const auto &lhs,
88 if (lhs->dialect->number == dialectToOrderFirst)
89 return rhs->dialect->number != dialectToOrderFirst;
90 if (rhs->dialect->number == dialectToOrderFirst)
92 return lhs->dialect->number < rhs->dialect->number;
95 unsigned dialectToOrderFirst = 0;
96 size_t elementsInByteGroup = 0;
97 auto iterRange = range;
98 for (
unsigned i = 1; i < 9; ++i) {
102 elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
106 auto byteSubRange = iterRange.take_front(elementsInByteGroup);
107 iterRange = iterRange.drop_front(byteSubRange.size());
110 llvm::stable_sort(byteSubRange, [&](
const auto &lhs,
const auto &rhs) {
111 return sortByDialect(dialectToOrderFirst, lhs, rhs);
117 dialectToOrderFirst = byteSubRange.back()->dialect->number;
120 if (iterRange.empty())
132 computeGlobalNumberingState(op);
143 auto addOpRegionsToNumber = [&](
Operation *op) {
151 for (
Region &region : regions)
152 numberContext.emplace_back(&region, opFirstValueID);
154 addOpRegionsToNumber(op);
157 while (!numberContext.empty()) {
159 std::tie(region, nextValueID) = numberContext.pop_back_val();
164 addOpRegionsToNumber(&op);
172 dialect.second->number = idx;
179 auto sortByRefCountFn = [](
const auto &lhs,
const auto &rhs) {
180 return lhs->refCount > rhs->refCount;
182 llvm::stable_sort(orderedAttrs, sortByRefCountFn);
183 llvm::stable_sort(orderedOpNames, sortByRefCountFn);
184 llvm::stable_sort(orderedTypes, sortByRefCountFn);
196 finalizeDialectResourceNumberings(op);
199 void IRNumberingState::computeGlobalNumberingState(
Operation *rootOp) {
227 bool hasUnresolvedIsolation;
232 unsigned operationID = 0;
258 OperationNumbering *numbering = opStack.pop_back_val().numbering;
259 if (!numbering->isIsolatedFromAbove.has_value())
260 numbering->isIsolatedFromAbove = true;
270 if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
271 Region *parentRegion = op->getParentRegion();
272 for (Value operand : op->getOperands()) {
273 Region *operandRegion = operand.getParentRegion();
274 if (operandRegion == parentRegion)
279 Operation *operandContainerOp = operandRegion->getParentOp();
280 auto it = std::find_if(
281 opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
284 return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
286 assert(it != opStack.rend() &&
"expected to find the container");
287 for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
290 state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
291 state.numbering->isIsolatedFromAbove = false;
300 numbering->isIsolatedFromAbove =
true;
301 operations.try_emplace(op, numbering);
303 opStack.emplace_back(StackState{
304 op, numbering, !numbering->isIsolatedFromAbove.has_value()});
309 void IRNumberingState::number(
Attribute attr) {
310 auto it = attrs.insert({attr,
nullptr});
312 ++it.first->second->refCount;
316 it.first->second = numbering;
317 orderedAttrs.push_back(numbering);
323 if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
324 numbering->
dialect = &numberDialect(opaqueAttr.getDialectNamespace());
327 numbering->dialect = &numberDialect(&attr.
getDialect());
334 for (
const auto &callback : config.getAttributeWriterCallbacks()) {
335 NumberingDialectWriter writer(*
this, config.getDialectVersionMap());
338 std::optional<StringRef> groupNameOverride;
339 if (
succeeded(callback->write(attr, groupNameOverride, writer))) {
340 if (groupNameOverride.has_value())
341 numbering->dialect = &numberDialect(*groupNameOverride);
346 if (
const auto *interface = numbering->dialect->interface) {
347 NumberingDialectWriter writer(*
this, config.getDialectVersionMap());
348 if (
succeeded(interface->writeAttribute(attr, writer)))
357 llvm::raw_null_ostream dummyOS;
358 attr.
print(dummyOS, tempState);
361 for (
const auto &it : tempState.getDialectResources())
362 number(it.getFirst(), it.getSecond().getArrayRef());
365 void IRNumberingState::number(
Block &block) {
368 valueIDs.try_emplace(arg, nextValueID++);
369 number(arg.getLoc());
370 number(arg.getType());
374 unsigned &numOps = blockOperationCounts[&block];
384 numbering = &numberDialect(dialect->getNamespace());
385 numbering->
interface = dyn_cast<BytecodeDialectInterface>(dialect);
386 numbering->
asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
391 auto IRNumberingState::numberDialect(StringRef dialect) ->
DialectNumbering & {
394 numbering =
new (dialectAllocator.Allocate())
400 void IRNumberingState::number(
Region &region) {
403 size_t firstValueID = nextValueID;
406 size_t blockCount = 0;
408 blockIDs.try_emplace(&it.value(), it.index());
414 regionBlockValueCounts.try_emplace(&region, blockCount,
415 nextValueID - firstValueID);
418 void IRNumberingState::number(
Operation &op) {
423 valueIDs.try_emplace(result, nextValueID++);
424 number(result.getType());
430 DictionaryAttr dictAttr;
436 if (!dictAttr.empty())
441 if (config.getDesiredBytecodeVersion() >=
446 auto iface = cast<BytecodeOpInterface>(op);
447 NumberingDialectWriter writer(*
this, config.getDialectVersionMap());
448 iface.writeProperties(writer);
467 dialectNumber = &numberDialect(dialect);
472 new (opNameAllocator.Allocate())
OpNameNumbering(dialectNumber, opName);
473 orderedOpNames.push_back(numbering);
476 void IRNumberingState::number(
Type type) {
477 auto it = types.insert({type,
nullptr});
482 auto *numbering =
new (typeAllocator.Allocate())
TypeNumbering(type);
483 it.first->second = numbering;
484 orderedTypes.push_back(numbering);
490 if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
491 numbering->
dialect = &numberDialect(opaqueType.getDialectNamespace());
501 for (
const auto &callback : config.getTypeWriterCallbacks()) {
502 NumberingDialectWriter writer(*
this, config.getDialectVersionMap());
505 std::optional<StringRef> groupNameOverride;
506 if (
succeeded(callback->write(type, groupNameOverride, writer))) {
507 if (groupNameOverride.has_value())
508 numbering->
dialect = &numberDialect(*groupNameOverride);
516 NumberingDialectWriter writer(*
this, config.getDialectVersionMap());
517 if (
succeeded(interface->writeType(type, writer)))
526 llvm::raw_null_ostream dummyOS;
527 type.
print(dummyOS, tempState);
530 for (
const auto &it : tempState.getDialectResources())
531 number(it.getFirst(), it.getSecond().getArrayRef());
534 void IRNumberingState::number(
Dialect *dialect,
539 "expected dialect owning a resource to implement OpAsmDialectInterface");
541 for (
const auto &resource : resources) {
543 if (!dialectNumber.
resources.insert(resource))
549 dialectNumber.
resourceMap.insert({numbering->key, numbering});
550 dialectResources.try_emplace(resource, numbering);
554 int64_t IRNumberingState::getDesiredBytecodeVersion()
const {
555 return config.getDesiredBytecodeVersion();
/// Construct a builder bound to the numbering entry `dialect` and to the
/// caller's `nextResourceID` counter. The counter is stored as a reference
/// member, so IDs handed out by this builder advance the caller's variable.
561 NumberingResourceBuilder(
DialectNumbering *dialect,
unsigned &nextResourceID)
562 : dialect(dialect), nextResourceID(nextResourceID) {}
/// Defaulted destructor; `override` ties it to the virtual destructor of the
/// base class this builder derives from.
563 ~NumberingResourceBuilder()
override =
default;
/// Handle a boolean resource entry: only `key` is used and is forwarded to
/// `numberEntry`; the boolean value itself is unnamed and ignored here.
568 void buildBool(StringRef key,
bool)
final { numberEntry(key); }
569 void buildString(StringRef key, StringRef)
final {
575 void numberEntry(StringRef key) {
578 auto *it = dialect->resourceMap.find(key);
579 if (it != dialect->resourceMap.end()) {
580 it->second->number = nextResourceID++;
581 it->second->isDeclaration =
false;
586 unsigned &nextResourceID;
590 void IRNumberingState::finalizeDialectResourceNumberings(
Operation *rootOp) {
591 unsigned nextResourceID = 0;
593 if (!dialect.asmInterface)
595 NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
596 dialect.asmInterface->buildResources(rootOp, dialect.resources,
603 for (
const auto &it : dialect.resourceMap)
604 if (it.second->isDeclaration)
605 it.second->number = nextResourceID++;
static void groupByDialectPerByte(T range)
Group and sort the elements of the given range by their parent dialect.
This class represents an opaque handle to a dialect resource entry.
Dialect * getDialect() const
Return the dialect that owns the resource.
This class is used to build resource entries for use by the printer.
This class provides management for the lifetime of the state used when printing the IR.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
MLIRContext * getContext() const
Return the context this attribute belongs to.
bool hasTrait()
Returns true if the type was registered with a particular trait.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
This class contains the configuration used for the bytecode writer.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
This class provides support for representing a failure result, or a valid value of type T.
virtual std::string getResourceKey(const AsmDialectResourceHandle &handle) const
Return a key to use for the given resource.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Dialect * getDialect() const
Return the dialect this operation is registered to if the dialect is loaded in the context,...
StringRef getDialectNamespace() const
Return the name of the dialect this operation is registered to.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
DictionaryAttr getRawDictionaryAttrs()
Return all attributes that are not stored as properties.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
void print(raw_ostream &os) const
Print the current type.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool hasTrait()
Returns true if the type was registered with a particular trait.
A utility class to encode the current walk stage for "generic" walkers.
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
This class manages numbering IR entities in preparation of bytecode emission.
IRNumberingState(Operation *op, const BytecodeWriterConfig &config)
int64_t getDesiredBytecodeVersion() const
Get the set desired bytecode version to emit.
bool isIsolatedFromAbove(Operation *op)
Return if the given operation is isolated from above.
@ kNativePropertiesEncoding
Support for encoding properties natively in bytecode instead of merged with the discardable attribute...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
llvm::StringMap< std::unique_ptr< DialectVersion > > & dialectVersionMap
A map containing dialect version information for each dialect to emit.
void writeType(Type type) override
Write a reference to the given type.
void writeVarInt(uint64_t) override
Stubbed out methods that are not used for numbering.
void writeOptionalAttribute(Attribute attr) override
void writeOwnedString(StringRef) override
Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the ...
IRNumberingState & state
The parent numbering state that is populated by this writer.
NumberingDialectWriter(IRNumberingState &state, llvm::StringMap< std::unique_ptr< DialectVersion >> &dialectVersionMap)
void writeAttribute(Attribute attr) override
Write a reference to the given attribute.
void writeResourceHandle(const AsmDialectResourceHandle &resource) override
Write the given handle to a dialect resource.
FailureOr< const DialectVersion * > getDialectVersion(StringRef dialectName) const override
Retrieve the dialect version by name if available.
void writeOwnedBool(bool value) override
Write a bool to the output stream.
void writeAPIntWithKnownWidth(const APInt &value) override
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
void writeSignedVarInt(int64_t value) override
Write a signed variable width integer to the output stream.
int64_t getBytecodeVersion() const override
Return the bytecode version being emitted for.
void writeOwnedBlob(ArrayRef< char > blob) override
Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the en...
void writeAPFloatWithKnownSemantics(const APFloat &value) override
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
DialectNumbering * dialect
The dialect of this value.
This class represents a numbering entry for a Dialect.
const BytecodeDialectInterface * interface
The bytecode dialect interface of the dialect if defined.
llvm::MapVector< StringRef, DialectResourceNumbering * > resourceMap
A mapping from resource key to the corresponding resource numbering entry.
SetVector< AsmDialectResourceHandle > resources
The referenced resources of this dialect.
const OpAsmDialectInterface * asmInterface
The asm dialect interface of the dialect if defined.
This class represents a numbering entry for a dialect resource.
This class represents the numbering entry of an operation.
This trait is used to determine if a storage user, like Type, is mutable or not.