//===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "IRNumbering.h"
12 #include "mlir/Bytecode/Encoding.h"
13 #include "mlir/IR/AsmState.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/OpDefinition.h"
17 using namespace mlir;
18 using namespace mlir::bytecode::detail;
20 //===----------------------------------------------------------------------===//
21 // NumberingDialectWriter
22 //===----------------------------------------------------------------------===//
27  llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
30  void writeAttribute(Attribute attr) override { state.number(attr); }
31  void writeOptionalAttribute(Attribute attr) override {
32  if (attr)
33  state.number(attr);
34  }
35  void writeType(Type type) override { state.number(type); }
36  void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
37  state.number(resource.getDialect(), resource);
38  }
40  /// Stubbed out methods that are not used for numbering.
41  void writeVarInt(uint64_t) override {}
42  void writeSignedVarInt(int64_t value) override {}
43  void writeAPIntWithKnownWidth(const APInt &value) override {}
44  void writeAPFloatWithKnownSemantics(const APFloat &value) override {}
45  void writeOwnedString(StringRef) override {
46  // TODO: It might be nice to prenumber strings and sort by the number of
47  // references. This could potentially be useful for optimizing things like
48  // file locations.
49  }
50  void writeOwnedBlob(ArrayRef<char> blob) override {}
51  void writeOwnedBool(bool value) override {}
53  int64_t getBytecodeVersion() const override {
55  }
58  getDialectVersion(StringRef dialectName) const override {
59  auto dialectEntry = dialectVersionMap.find(dialectName);
60  if (dialectEntry == dialectVersionMap.end())
61  return failure();
62  return dialectEntry->getValue().get();
63  }
65  /// The parent numbering state that is populated by this writer.
68  /// A map containing dialect version information for each dialect to emit.
69  llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
70 };
72 //===----------------------------------------------------------------------===//
73 // IR Numbering
74 //===----------------------------------------------------------------------===//
76 /// Group and sort the elements of the given range by their parent dialect. This
77 /// grouping is applied to sub-sections of the ranged defined by how many bytes
78 /// it takes to encode a varint index to that sub-section.
79 template <typename T>
80 static void groupByDialectPerByte(T range) {
81  if (range.empty())
82  return;
84  // A functor used to sort by a given dialect, with a desired dialect to be
85  // ordered first (to better enable sharing of dialects across byte groups).
86  auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
87  const auto &rhs) {
88  if (lhs->dialect->number == dialectToOrderFirst)
89  return rhs->dialect->number != dialectToOrderFirst;
90  if (rhs->dialect->number == dialectToOrderFirst)
91  return false;
92  return lhs->dialect->number < rhs->dialect->number;
93  };
95  unsigned dialectToOrderFirst = 0;
96  size_t elementsInByteGroup = 0;
97  auto iterRange = range;
98  for (unsigned i = 1; i < 9; ++i) {
99  // Update the number of elements in the current byte grouping. Reminder
100  // that varint encodes 7-bits per byte, so that's how we compute the
101  // number of elements in each byte grouping.
102  elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
104  // Slice out the sub-set of elements that are in the current byte grouping
105  // to be sorted.
106  auto byteSubRange = iterRange.take_front(elementsInByteGroup);
107  iterRange = iterRange.drop_front(byteSubRange.size());
109  // Sort the sub range for this byte.
110  llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
111  return sortByDialect(dialectToOrderFirst, lhs, rhs);
112  });
114  // Update the dialect to order first to be the dialect at the end of the
115  // current grouping. This seeks to allow larger dialect groupings across
116  // byte boundaries.
117  dialectToOrderFirst = byteSubRange.back()->dialect->number;
119  // If the data range is now empty, we are done.
120  if (iterRange.empty())
121  break;
122  }
124  // Assign the entry numbers based on the sort order.
125  for (auto [idx, value] : llvm::enumerate(range))
126  value->number = idx;
127 }
130  const BytecodeWriterConfig &config)
131  : config(config) {
132  computeGlobalNumberingState(op);
134  // Number the root operation.
135  number(*op);
137  // A worklist of region contexts to number and the next value id before that
138  // region.
141  // Functor to push the regions of the given operation onto the numbering
142  // context.
143  auto addOpRegionsToNumber = [&](Operation *op) {
144  MutableArrayRef<Region> regions = op->getRegions();
145  if (regions.empty())
146  return;
148  // Isolated regions don't share value numbers with their parent, so we can
149  // start numbering these regions at zero.
150  unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
151  for (Region &region : regions)
152  numberContext.emplace_back(&region, opFirstValueID);
153  };
154  addOpRegionsToNumber(op);
156  // Iteratively process each of the nested regions.
157  while (!numberContext.empty()) {
158  Region *region;
159  std::tie(region, nextValueID) = numberContext.pop_back_val();
160  number(*region);
162  // Traverse into nested regions.
163  for (Operation &op : region->getOps())
164  addOpRegionsToNumber(&op);
165  }
167  // Number each of the dialects. For now this is just in the order they were
168  // found, given that the number of dialects on average is small enough to fit
169  // within a singly byte (128). If we ever have real world use cases that have
170  // a huge number of dialects, this could be made more intelligent.
171  for (auto [idx, dialect] : llvm::enumerate(dialects))
172  dialect.second->number = idx;
174  // Number each of the recorded components within each dialect.
176  // First sort by ref count so that the most referenced elements are first. We
177  // try to bias more heavily used elements to the front. This allows for more
178  // frequently referenced things to be encoded using smaller varints.
179  auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
180  return lhs->refCount > rhs->refCount;
181  };
182  llvm::stable_sort(orderedAttrs, sortByRefCountFn);
183  llvm::stable_sort(orderedOpNames, sortByRefCountFn);
184  llvm::stable_sort(orderedTypes, sortByRefCountFn);
186  // After that, we apply a secondary ordering based on the parent dialect. This
187  // ordering is applied to sub-sections of the element list defined by how many
188  // bytes it takes to encode a varint index to that sub-section. This allows
189  // for more efficiently encoding components of the same dialect (e.g. we only
190  // have to encode the dialect reference once).
195  // Finalize the numbering of the dialect resources.
196  finalizeDialectResourceNumberings(op);
197 }
199 void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
200  // A simple state struct tracking data used when walking operations.
201  struct StackState {
202  /// The operation currently being walked.
203  Operation *op;
205  /// The numbering of the operation.
206  OperationNumbering *numbering;
208  /// A flag indicating if the current state or one of its parents has
209  /// unresolved isolation status. This is tracked separately from the
210  /// isIsolatedFromAbove bit on `numbering` because we need to be able to
211  /// handle the given case:
212  /// top.op {
213  /// %value = ...
214  /// middle.op {
215  /// %value2 = ...
216  /// inner.op {
217  /// // Here we mark `inner.op` as not isolated. Note `middle.op`
218  /// // isn't known not isolated yet.
219  /// use.op %value2
220  ///
221  /// // Here inner.op is already known to be non-isolated, but
222  /// // `middle.op` is now also discovered to be non-isolated.
223  /// use.op %value
224  /// }
225  /// }
226  /// }
227  bool hasUnresolvedIsolation;
228  };
230  // Compute a global operation ID numbering according to the pre-order walk of
231  // the IR. This is used as reference to construct use-list orders.
232  unsigned operationID = 0;
234  // Walk each of the operations within the IR, tracking a stack of operations
235  // as we recurse into nested regions. This walk method hooks in at two stages
236  // during the walk:
237  //
238  // BeforeAllRegions:
239  // Here we generate a numbering for the operation and push it onto the
240  // stack if it has regions. We also compute the isolation status of parent
241  // regions at this stage. This is done by checking the parent regions of
242  // operands used by the operation, and marking each region between the
243  // the operand region and the current as not isolated. See
244  // StackState::hasUnresolvedIsolation above for an example.
245  //
246  // AfterAllRegions:
247  // Here we pop the operation from the stack, and if it hasn't been marked
248  // as non-isolated, we mark it as so. A non-isolated use would have been
249  // found while walking the regions, so it is safe to mark the operation at
250  // this point.
251  //
252  SmallVector<StackState> opStack;
253  rootOp->walk([&](Operation *op, const WalkStage &stage) {
254  // After visiting all nested regions, we pop the operation from the stack.
255  if (op->getNumRegions() && stage.isAfterAllRegions()) {
256  // If no non-isolated uses were found, we can safely mark this operation
257  // as isolated from above.
258  OperationNumbering *numbering = opStack.pop_back_val().numbering;
259  if (!numbering->isIsolatedFromAbove.has_value())
260  numbering->isIsolatedFromAbove = true;
261  return;
262  }
264  // When visiting before nested regions, we process "IsolatedFromAbove"
265  // checks and compute the number for this operation.
266  if (!stage.isBeforeAllRegions())
267  return;
268  // Update the isolation status of parent regions if any have yet to be
269  // resolved.
270  if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
271  Region *parentRegion = op->getParentRegion();
272  for (Value operand : op->getOperands()) {
273  Region *operandRegion = operand.getParentRegion();
274  if (operandRegion == parentRegion)
275  continue;
276  // We've found a use of an operand outside of the current region,
277  // walk the operation stack searching for the parent operation,
278  // marking every region on the way as not isolated.
279  Operation *operandContainerOp = operandRegion->getParentOp();
280  auto it = std::find_if(
281  opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
282  // We only need to mark up to the container region, or the first
283  // that has an unresolved status.
284  return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
285  });
286  assert(it != opStack.rend() && "expected to find the container");
287  for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
288  // If we stopped at a region that knows its isolation status, we can
289  // stop updating the isolation status for the parent regions.
290  state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
291  state.numbering->isIsolatedFromAbove = false;
292  }
293  }
294  }
296  // Compute the number for this op and push it onto the stack.
297  auto *numbering =
298  new (opAllocator.Allocate()) OperationNumbering(operationID++);
300  numbering->isIsolatedFromAbove = true;
301  operations.try_emplace(op, numbering);
302  if (op->getNumRegions()) {
303  opStack.emplace_back(StackState{
304  op, numbering, !numbering->isIsolatedFromAbove.has_value()});
305  }
306  });
307 }
309 void IRNumberingState::number(Attribute attr) {
310  auto it = attrs.insert({attr, nullptr});
311  if (!it.second) {
312  ++it.first->second->refCount;
313  return;
314  }
315  auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
316  it.first->second = numbering;
317  orderedAttrs.push_back(numbering);
319  // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
320  // have a registered dialect when it got created. We don't want to encode this
321  // as the builtin OpaqueAttr, we want to encode it as if the dialect was
322  // actually loaded.
323  if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
324  numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
325  return;
326  }
327  numbering->dialect = &numberDialect(&attr.getDialect());
329  // If this attribute will be emitted using the bytecode format, perform a
330  // dummy writing to number any nested components.
331  // TODO: We don't allow custom encodings for mutable attributes right now.
332  if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
333  // Try overriding emission with callbacks.
334  for (const auto &callback : config.getAttributeWriterCallbacks()) {
335  NumberingDialectWriter writer(*this, config.getDialectVersionMap());
336  // The client has the ability to override the group name through the
337  // callback.
338  std::optional<StringRef> groupNameOverride;
339  if (succeeded(callback->write(attr, groupNameOverride, writer))) {
340  if (groupNameOverride.has_value())
341  numbering->dialect = &numberDialect(*groupNameOverride);
342  return;
343  }
344  }
346  if (const auto *interface = numbering->dialect->interface) {
347  NumberingDialectWriter writer(*this, config.getDialectVersionMap());
348  if (succeeded(interface->writeAttribute(attr, writer)))
349  return;
350  }
351  }
352  // If this attribute will be emitted using the fallback, number the nested
353  // dialect resources. We don't number everything (e.g. no nested
354  // attributes/types), because we don't want to encode things we won't decode
355  // (the textual format can't really share much).
356  AsmState tempState(attr.getContext());
357  llvm::raw_null_ostream dummyOS;
358  attr.print(dummyOS, tempState);
360  // Number the used dialect resources.
361  for (const auto &it : tempState.getDialectResources())
362  number(it.getFirst(), it.getSecond().getArrayRef());
363 }
365 void IRNumberingState::number(Block &block) {
366  // Number the arguments of the block.
367  for (BlockArgument arg : block.getArguments()) {
368  valueIDs.try_emplace(arg, nextValueID++);
369  number(arg.getLoc());
370  number(arg.getType());
371  }
373  // Number the operations in this block.
374  unsigned &numOps = blockOperationCounts[&block];
375  for (Operation &op : block) {
376  number(op);
377  ++numOps;
378  }
379 }
381 auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
382  DialectNumbering *&numbering = registeredDialects[dialect];
383  if (!numbering) {
384  numbering = &numberDialect(dialect->getNamespace());
385  numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
386  numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
387  }
388  return *numbering;
389 }
391 auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
392  DialectNumbering *&numbering = dialects[dialect];
393  if (!numbering) {
394  numbering = new (dialectAllocator.Allocate())
395  DialectNumbering(dialect, dialects.size() - 1);
396  }
397  return *numbering;
398 }
400 void IRNumberingState::number(Region &region) {
401  if (region.empty())
402  return;
403  size_t firstValueID = nextValueID;
405  // Number the blocks within this region.
406  size_t blockCount = 0;
407  for (auto it : llvm::enumerate(region)) {
408  blockIDs.try_emplace(&it.value(), it.index());
409  number(it.value());
410  ++blockCount;
411  }
413  // Remember the number of blocks and values in this region.
414  regionBlockValueCounts.try_emplace(&region, blockCount,
415  nextValueID - firstValueID);
416 }
418 void IRNumberingState::number(Operation &op) {
419  // Number the components of an operation that won't be numbered elsewhere
420  // (e.g. we don't number operands, regions, or successors here).
421  number(op.getName());
422  for (OpResult result : op.getResults()) {
423  valueIDs.try_emplace(result, nextValueID++);
424  number(result.getType());
425  }
427  // Prior to a version with native property encoding, or when properties are
428  // not used, we need to number also the merged dictionary containing both the
429  // inherent and discardable attribute.
430  DictionaryAttr dictAttr;
431  if (config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding)
432  dictAttr = op.getRawDictionaryAttrs();
433  else
434  dictAttr = op.getAttrDictionary();
435  // Only number the operation's dictionary if it isn't empty.
436  if (!dictAttr.empty())
437  number(dictAttr);
439  // Visit the operation properties (if any) to make sure referenced attributes
440  // are numbered.
441  if (config.getDesiredBytecodeVersion() >=
444  if (op.isRegistered()) {
445  // Operation that have properties *must* implement this interface.
446  auto iface = cast<BytecodeOpInterface>(op);
447  NumberingDialectWriter writer(*this, config.getDialectVersionMap());
448  iface.writeProperties(writer);
449  } else {
450  // Unregistered op are storing properties as an optional attribute.
451  if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>())
452  number(prop);
453  }
454  }
456  number(op.getLoc());
457 }
459 void IRNumberingState::number(OperationName opName) {
460  OpNameNumbering *&numbering = opNames[opName];
461  if (numbering) {
462  ++numbering->refCount;
463  return;
464  }
465  DialectNumbering *dialectNumber = nullptr;
466  if (Dialect *dialect = opName.getDialect())
467  dialectNumber = &numberDialect(dialect);
468  else
469  dialectNumber = &numberDialect(opName.getDialectNamespace());
471  numbering =
472  new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
473  orderedOpNames.push_back(numbering);
474 }
476 void IRNumberingState::number(Type type) {
477  auto it = types.insert({type, nullptr});
478  if (!it.second) {
479  ++it.first->second->refCount;
480  return;
481  }
482  auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
483  it.first->second = numbering;
484  orderedTypes.push_back(numbering);
486  // Check for OpaqueType, which is a dialect-specific type that didn't have a
487  // registered dialect when it got created. We don't want to encode this as the
488  // builtin OpaqueType, we want to encode it as if the dialect was actually
489  // loaded.
490  if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
491  numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
492  return;
493  }
494  numbering->dialect = &numberDialect(&type.getDialect());
496  // If this type will be emitted using the bytecode format, perform a dummy
497  // writing to number any nested components.
498  // TODO: We don't allow custom encodings for mutable types right now.
499  if (!type.hasTrait<TypeTrait::IsMutable>()) {
500  // Try overriding emission with callbacks.
501  for (const auto &callback : config.getTypeWriterCallbacks()) {
502  NumberingDialectWriter writer(*this, config.getDialectVersionMap());
503  // The client has the ability to override the group name through the
504  // callback.
505  std::optional<StringRef> groupNameOverride;
506  if (succeeded(callback->write(type, groupNameOverride, writer))) {
507  if (groupNameOverride.has_value())
508  numbering->dialect = &numberDialect(*groupNameOverride);
509  return;
510  }
511  }
513  // If this attribute will be emitted using the bytecode format, perform a
514  // dummy writing to number any nested components.
515  if (const auto *interface = numbering->dialect->interface) {
516  NumberingDialectWriter writer(*this, config.getDialectVersionMap());
517  if (succeeded(interface->writeType(type, writer)))
518  return;
519  }
520  }
521  // If this type will be emitted using the fallback, number the nested dialect
522  // resources. We don't number everything (e.g. no nested attributes/types),
523  // because we don't want to encode things we won't decode (the textual format
524  // can't really share much).
525  AsmState tempState(type.getContext());
526  llvm::raw_null_ostream dummyOS;
527  type.print(dummyOS, tempState);
529  // Number the used dialect resources.
530  for (const auto &it : tempState.getDialectResources())
531  number(it.getFirst(), it.getSecond().getArrayRef());
532 }
534 void IRNumberingState::number(Dialect *dialect,
536  DialectNumbering &dialectNumber = numberDialect(dialect);
537  assert(
538  dialectNumber.asmInterface &&
539  "expected dialect owning a resource to implement OpAsmDialectInterface");
541  for (const auto &resource : resources) {
542  // Check if this is a newly seen resource.
543  if (!dialectNumber.resources.insert(resource))
544  return;
546  auto *numbering =
547  new (resourceAllocator.Allocate()) DialectResourceNumbering(
548  dialectNumber.asmInterface->getResourceKey(resource));
549  dialectNumber.resourceMap.insert({numbering->key, numbering});
550  dialectResources.try_emplace(resource, numbering);
551  }
552 }
554 int64_t IRNumberingState::getDesiredBytecodeVersion() const {
555  return config.getDesiredBytecodeVersion();
556 }
558 namespace {
559 /// A dummy resource builder used to number dialect resources.
560 struct NumberingResourceBuilder : public AsmResourceBuilder {
561  NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
562  : dialect(dialect), nextResourceID(nextResourceID) {}
563  ~NumberingResourceBuilder() override = default;
565  void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
566  numberEntry(key);
567  }
568  void buildBool(StringRef key, bool) final { numberEntry(key); }
569  void buildString(StringRef key, StringRef) final {
570  // TODO: We could pre-number the value string here as well.
571  numberEntry(key);
572  }
574  /// Number the dialect entry for the given key.
575  void numberEntry(StringRef key) {
576  // TODO: We could pre-number resource key strings here as well.
578  auto *it = dialect->resourceMap.find(key);
579  if (it != dialect->resourceMap.end()) {
580  it->second->number = nextResourceID++;
581  it->second->isDeclaration = false;
582  }
583  }
585  DialectNumbering *dialect;
586  unsigned &nextResourceID;
587 };
588 } // namespace
590 void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
591  unsigned nextResourceID = 0;
592  for (DialectNumbering &dialect : getDialects()) {
593  if (!dialect.asmInterface)
594  continue;
595  NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
596  dialect.asmInterface->buildResources(rootOp, dialect.resources,
597  entryBuilder);
599  // Number any resources that weren't added by the dialect. This can happen
600  // if there was no backing data to the resource, but we still want these
601  // resource references to roundtrip, so we number them and indicate that the
602  // data is missing.
603  for (const auto &it : dialect.resourceMap)
604  if (it.second->isDeclaration)
605  it.second->number = nextResourceID++;
606  }
607 }
