15 #include "llvm/Support/ErrorHandling.h"
72 auto sortByDialect = [](
unsigned dialectToOrderFirst,
const auto &lhs,
74 if (lhs->dialect->number == dialectToOrderFirst)
75 return rhs->dialect->number != dialectToOrderFirst;
76 if (rhs->dialect->number == dialectToOrderFirst)
78 return lhs->dialect->number < rhs->dialect->number;
81 unsigned dialectToOrderFirst = 0;
82 size_t elementsInByteGroup = 0;
83 auto iterRange = range;
84 for (
unsigned i = 1; i < 9; ++i) {
88 elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
92 auto byteSubRange = iterRange.take_front(elementsInByteGroup);
93 iterRange = iterRange.drop_front(byteSubRange.size());
96 llvm::stable_sort(byteSubRange, [&](
const auto &lhs,
const auto &rhs) {
97 return sortByDialect(dialectToOrderFirst, lhs, rhs);
103 dialectToOrderFirst = byteSubRange.back()->dialect->number;
106 if (iterRange.empty())
118 computeGlobalNumberingState(op);
129 auto addOpRegionsToNumber = [&](
Operation *op) {
137 for (
Region ®ion : regions)
138 numberContext.emplace_back(®ion, opFirstValueID);
140 addOpRegionsToNumber(op);
143 while (!numberContext.empty()) {
145 std::tie(region, nextValueID) = numberContext.pop_back_val();
150 addOpRegionsToNumber(&op);
158 dialect.second->number = idx;
165 auto sortByRefCountFn = [](
const auto &lhs,
const auto &rhs) {
166 return lhs->refCount > rhs->refCount;
168 llvm::stable_sort(orderedAttrs, sortByRefCountFn);
169 llvm::stable_sort(orderedOpNames, sortByRefCountFn);
170 llvm::stable_sort(orderedTypes, sortByRefCountFn);
182 finalizeDialectResourceNumberings(op);
185 void IRNumberingState::computeGlobalNumberingState(
Operation *rootOp) {
213 bool hasUnresolvedIsolation;
218 unsigned operationID = 0;
244 OperationNumbering *numbering = opStack.pop_back_val().numbering;
245 if (!numbering->isIsolatedFromAbove.has_value())
246 numbering->isIsolatedFromAbove = true;
256 if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
257 Region *parentRegion = op->getParentRegion();
258 for (Value operand : op->getOperands()) {
259 Region *operandRegion = operand.getParentRegion();
260 if (operandRegion == parentRegion)
265 Operation *operandContainerOp = operandRegion->getParentOp();
266 auto it = std::find_if(
267 opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
270 return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
272 assert(it != opStack.rend() &&
"expected to find the container");
273 for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
276 state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
277 state.numbering->isIsolatedFromAbove = false;
286 numbering->isIsolatedFromAbove =
true;
287 operations.try_emplace(op, numbering);
289 opStack.emplace_back(StackState{
290 op, numbering, !numbering->isIsolatedFromAbove.has_value()});
295 void IRNumberingState::number(
Attribute attr) {
296 auto it = attrs.insert({attr,
nullptr});
298 ++it.first->second->refCount;
302 it.first->second = numbering;
303 orderedAttrs.push_back(numbering);
309 if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
310 numbering->
dialect = &numberDialect(opaqueAttr.getDialectNamespace());
313 numbering->dialect = &numberDialect(&attr.
getDialect());
320 for (
const auto &callback : config.getAttributeWriterCallbacks()) {
321 NumberingDialectWriter writer(*
this);
324 std::optional<StringRef> groupNameOverride;
325 if (
succeeded(callback->write(attr, groupNameOverride, writer))) {
326 if (groupNameOverride.has_value())
327 numbering->dialect = &numberDialect(*groupNameOverride);
332 if (
const auto *interface = numbering->dialect->interface) {
333 NumberingDialectWriter writer(*
this);
334 if (
succeeded(interface->writeAttribute(attr, writer)))
343 llvm::raw_null_ostream dummyOS;
344 attr.
print(dummyOS, tempState);
347 for (
const auto &it : tempState.getDialectResources())
348 number(it.getFirst(), it.getSecond().getArrayRef());
351 void IRNumberingState::number(
Block &block) {
354 valueIDs.try_emplace(arg, nextValueID++);
355 number(arg.getLoc());
356 number(arg.getType());
360 unsigned &numOps = blockOperationCounts[&block];
370 numbering = &numberDialect(dialect->getNamespace());
371 numbering->
interface = dyn_cast<BytecodeDialectInterface>(dialect);
372 numbering->
asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
377 auto IRNumberingState::numberDialect(StringRef dialect) ->
DialectNumbering & {
380 numbering =
new (dialectAllocator.Allocate())
386 void IRNumberingState::number(
Region ®ion) {
389 size_t firstValueID = nextValueID;
392 size_t blockCount = 0;
394 blockIDs.try_emplace(&it.value(), it.index());
400 regionBlockValueCounts.try_emplace(®ion, blockCount,
401 nextValueID - firstValueID);
404 void IRNumberingState::number(
Operation &op) {
409 valueIDs.try_emplace(result, nextValueID++);
410 number(result.getType());
417 if (config.getDesiredBytecodeVersion() < 5)
419 if (!dictAttr.empty())
424 if (config.getDesiredBytecodeVersion() >= 5 &&
428 auto iface = cast<BytecodeOpInterface>(op);
429 NumberingDialectWriter writer(*
this);
430 iface.writeProperties(writer);
449 dialectNumber = &numberDialect(dialect);
454 new (opNameAllocator.Allocate())
OpNameNumbering(dialectNumber, opName);
455 orderedOpNames.push_back(numbering);
458 void IRNumberingState::number(
Type type) {
459 auto it = types.insert({type,
nullptr});
464 auto *numbering =
new (typeAllocator.Allocate())
TypeNumbering(type);
465 it.first->second = numbering;
466 orderedTypes.push_back(numbering);
472 if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
473 numbering->
dialect = &numberDialect(opaqueType.getDialectNamespace());
483 for (
const auto &callback : config.getTypeWriterCallbacks()) {
484 NumberingDialectWriter writer(*
this);
487 std::optional<StringRef> groupNameOverride;
488 if (
succeeded(callback->write(type, groupNameOverride, writer))) {
489 if (groupNameOverride.has_value())
490 numbering->
dialect = &numberDialect(*groupNameOverride);
498 NumberingDialectWriter writer(*
this);
499 if (
succeeded(interface->writeType(type, writer)))
508 llvm::raw_null_ostream dummyOS;
509 type.
print(dummyOS, tempState);
512 for (
const auto &it : tempState.getDialectResources())
513 number(it.getFirst(), it.getSecond().getArrayRef());
516 void IRNumberingState::number(
Dialect *dialect,
521 "expected dialect owning a resource to implement OpAsmDialectInterface");
523 for (
const auto &resource : resources) {
525 if (!dialectNumber.
resources.insert(resource))
531 dialectNumber.
resourceMap.insert({numbering->key, numbering});
532 dialectResources.try_emplace(resource, numbering);
536 int64_t IRNumberingState::getDesiredBytecodeVersion()
const {
537 return config.getDesiredBytecodeVersion();
543 NumberingResourceBuilder(
DialectNumbering *dialect,
unsigned &nextResourceID)
544 : dialect(dialect), nextResourceID(nextResourceID) {}
545 ~NumberingResourceBuilder()
override =
default;
550 void buildBool(StringRef key,
bool)
final { numberEntry(key); }
551 void buildString(StringRef key, StringRef)
final {
557 void numberEntry(StringRef key) {
560 auto it = dialect->resourceMap.find(key);
561 if (it != dialect->resourceMap.end()) {
562 it->second->number = nextResourceID++;
563 it->second->isDeclaration =
false;
568 unsigned &nextResourceID;
572 void IRNumberingState::finalizeDialectResourceNumberings(
Operation *rootOp) {
573 unsigned nextResourceID = 0;
575 if (!dialect.asmInterface)
577 NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
578 dialect.asmInterface->buildResources(rootOp, dialect.resources,
585 for (
const auto &it : dialect.resourceMap)
586 if (it.second->isDeclaration)
587 it.second->number = nextResourceID++;
static void groupByDialectPerByte(T range)
Group and sort the elements of the given range by their parent dialect.
This class represents an opaque handle to a dialect resource entry.
Dialect * getDialect() const
Return the dialect that owns the resource.
This class is used to build resource entries for use by the printer.
This class provides management for the lifetime of the state used when printing the IR.
Attributes are known-constant values of operations.
Dialect & getDialect() const
Get the dialect this attribute is registered to.
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
MLIRContext * getContext() const
Return the context this attribute belongs to.
bool hasTrait()
Returns true if the type was registered with a particular trait.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
This class contains the configuration used for the bytecode writer.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual std::string getResourceKey(const AsmDialectResourceHandle &handle) const
Return a key to use for the given resource.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Dialect * getDialect() const
Return the dialect this operation is registered to if the dialect is loaded in the context,...
StringRef getDialectNamespace() const
Return the name of the dialect this operation is registered to.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool isRegistered()
Returns true if this operation has a registered operation description, otherwise false.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
void print(raw_ostream &os) const
Print the current type.
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool hasTrait()
Returns true if the type was registered with a particular trait.
A utility class to encode the current walk stage for "generic" walkers.
bool isAfterAllRegions() const
Return true if parent operation is being visited after all regions.
bool isBeforeAllRegions() const
Return true if parent operation is being visited before all regions.
This class manages numbering IR entities in preparation of bytecode emission.
IRNumberingState(Operation *op, const BytecodeWriterConfig &config)
int64_t getDesiredBytecodeVersion() const
Get the set desired bytecode version to emit.
bool isIsolatedFromAbove(Operation *op)
Return if the given operation is isolated from above.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
void writeType(Type type) override
Write a reference to the given type.
NumberingDialectWriter(IRNumberingState &state)
void writeVarInt(uint64_t) override
Stubbed out methods that are not used for numbering.
void writeOptionalAttribute(Attribute attr) override
void writeOwnedString(StringRef) override
Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the ...
IRNumberingState & state
The parent numbering state that is populated by this writer.
void writeAttribute(Attribute attr) override
Write a reference to the given attribute.
void writeResourceHandle(const AsmDialectResourceHandle &resource) override
Write the given handle to a dialect resource.
void writeOwnedBool(bool value) override
Write a bool to the output stream.
void writeAPIntWithKnownWidth(const APInt &value) override
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
void writeSignedVarInt(int64_t value) override
Write a signed variable width integer to the output stream.
int64_t getBytecodeVersion() const override
Return the bytecode version being emitted for.
void writeOwnedBlob(ArrayRef< char > blob) override
Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the en...
void writeAPFloatWithKnownSemantics(const APFloat &value) override
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
DialectNumbering * dialect
The dialect of this value.
This class represents a numbering entry for an Dialect.
const BytecodeDialectInterface * interface
The bytecode dialect interface of the dialect if defined.
llvm::MapVector< StringRef, DialectResourceNumbering * > resourceMap
A mapping from resource key to the corresponding resource numbering entry.
SetVector< AsmDialectResourceHandle > resources
The referenced resources of this dialect.
const OpAsmDialectInterface * asmInterface
The asm dialect interface of the dialect if defined.
This class represents a numbering entry for a dialect resource.
This class represents the numbering entry of an operation.
This trait is used to determine if a storage user, like Type, is mutable or not.