MLIR  17.0.0git
IRNumbering.cpp
Go to the documentation of this file.
1 //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRNumbering.h"
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 
16 using namespace mlir;
17 using namespace mlir::bytecode::detail;
18 
19 //===----------------------------------------------------------------------===//
20 // NumberingDialectWriter
21 //===----------------------------------------------------------------------===//
22 
25 
26  void writeAttribute(Attribute attr) override { state.number(attr); }
27  void writeType(Type type) override { state.number(type); }
28  void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
29  state.number(resource.getDialect(), resource);
30  }
31 
32  /// Stubbed out methods that are not used for numbering.
33  void writeVarInt(uint64_t) override {}
34  void writeSignedVarInt(int64_t value) override {}
35  void writeAPIntWithKnownWidth(const APInt &value) override {}
36  void writeAPFloatWithKnownSemantics(const APFloat &value) override {}
37  void writeOwnedString(StringRef) override {
38  // TODO: It might be nice to prenumber strings and sort by the number of
39  // references. This could potentially be useful for optimizing things like
40  // file locations.
41  }
42  void writeOwnedBlob(ArrayRef<char> blob) override {}
43 
44  /// The parent numbering state that is populated by this writer.
46 };
47 
48 //===----------------------------------------------------------------------===//
49 // IR Numbering
50 //===----------------------------------------------------------------------===//
51 
52 /// Group and sort the elements of the given range by their parent dialect. This
53 /// grouping is applied to sub-sections of the ranged defined by how many bytes
54 /// it takes to encode a varint index to that sub-section.
55 template <typename T>
56 static void groupByDialectPerByte(T range) {
57  if (range.empty())
58  return;
59 
60  // A functor used to sort by a given dialect, with a desired dialect to be
61  // ordered first (to better enable sharing of dialects across byte groups).
62  auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
63  const auto &rhs) {
64  if (lhs->dialect->number == dialectToOrderFirst)
65  return rhs->dialect->number != dialectToOrderFirst;
66  if (rhs->dialect->number == dialectToOrderFirst)
67  return false;
68  return lhs->dialect->number < rhs->dialect->number;
69  };
70 
71  unsigned dialectToOrderFirst = 0;
72  size_t elementsInByteGroup = 0;
73  auto iterRange = range;
74  for (unsigned i = 1; i < 9; ++i) {
75  // Update the number of elements in the current byte grouping. Reminder
76  // that varint encodes 7-bits per byte, so that's how we compute the
77  // number of elements in each byte grouping.
78  elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
79 
80  // Slice out the sub-set of elements that are in the current byte grouping
81  // to be sorted.
82  auto byteSubRange = iterRange.take_front(elementsInByteGroup);
83  iterRange = iterRange.drop_front(byteSubRange.size());
84 
85  // Sort the sub range for this byte.
86  llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
87  return sortByDialect(dialectToOrderFirst, lhs, rhs);
88  });
89 
90  // Update the dialect to order first to be the dialect at the end of the
91  // current grouping. This seeks to allow larger dialect groupings across
92  // byte boundaries.
93  dialectToOrderFirst = byteSubRange.back()->dialect->number;
94 
95  // If the data range is now empty, we are done.
96  if (iterRange.empty())
97  break;
98  }
99 
100  // Assign the entry numbers based on the sort order.
101  for (auto [idx, value] : llvm::enumerate(range))
102  value->number = idx;
103 }
104 
106  // Number the root operation.
107  number(*op);
108 
109  // Push all of the regions of the root operation onto the worklist.
111  for (Region &region : op->getRegions())
112  numberContext.emplace_back(&region, nextValueID);
113 
114  // Iteratively process each of the nested regions.
115  while (!numberContext.empty()) {
116  Region *region;
117  std::tie(region, nextValueID) = numberContext.pop_back_val();
118  number(*region);
119 
120  // Traverse into nested regions.
121  for (Operation &op : region->getOps()) {
122  // Isolated regions don't share value numbers with their parent, so we can
123  // start numbering these regions at zero.
124  unsigned opFirstValueID =
125  op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID;
126  for (Region &region : op.getRegions())
127  numberContext.emplace_back(&region, opFirstValueID);
128  }
129  }
130 
131  // Number each of the dialects. For now this is just in the order they were
132  // found, given that the number of dialects on average is small enough to fit
133  // within a single byte (128). If we ever have real world use cases that have
134  // a huge number of dialects, this could be made more intelligent.
135  for (auto [idx, dialect] : llvm::enumerate(dialects))
136  dialect.second->number = idx;
137 
138  // Number each of the recorded components within each dialect.
139 
140  // First sort by ref count so that the most referenced elements are first. We
141  // try to bias more heavily used elements to the front. This allows for more
142  // frequently referenced things to be encoded using smaller varints.
143  auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
144  return lhs->refCount > rhs->refCount;
145  };
146  llvm::stable_sort(orderedAttrs, sortByRefCountFn);
147  llvm::stable_sort(orderedOpNames, sortByRefCountFn);
148  llvm::stable_sort(orderedTypes, sortByRefCountFn);
149 
150  // After that, we apply a secondary ordering based on the parent dialect. This
151  // ordering is applied to sub-sections of the element list defined by how many
152  // bytes it takes to encode a varint index to that sub-section. This allows
153  // for more efficiently encoding components of the same dialect (e.g. we only
154  // have to encode the dialect reference once).
158 
159  // Finalize the numbering of the dialect resources.
160  finalizeDialectResourceNumberings(op);
161 }
162 
163 void IRNumberingState::number(Attribute attr) {
164  auto it = attrs.insert({attr, nullptr});
165  if (!it.second) {
166  ++it.first->second->refCount;
167  return;
168  }
169  auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
170  it.first->second = numbering;
171  orderedAttrs.push_back(numbering);
172 
173  // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
174  // have a registered dialect when it got created. We don't want to encode this
175  // as the builtin OpaqueAttr, we want to encode it as if the dialect was
176  // actually loaded.
177  if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
178  numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
179  return;
180  }
181  numbering->dialect = &numberDialect(&attr.getDialect());
182 
183  // If this attribute will be emitted using the bytecode format, perform a
184  // dummy writing to number any nested components.
185  if (const auto *interface = numbering->dialect->interface) {
186  // TODO: We don't allow custom encodings for mutable attributes right now.
187  if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
188  NumberingDialectWriter writer(*this);
189  if (succeeded(interface->writeAttribute(attr, writer)))
190  return;
191  }
192  }
193  // If this attribute will be emitted using the fallback, number the nested
194  // dialect resources. We don't number everything (e.g. no nested
195  // attributes/types), because we don't want to encode things we won't decode
196  // (the textual format can't really share much).
197  AsmState tempState(attr.getContext());
198  llvm::raw_null_ostream dummyOS;
199  attr.print(dummyOS, tempState);
200 
201  // Number the used dialect resources.
202  for (const auto &it : tempState.getDialectResources())
203  number(it.getFirst(), it.getSecond().getArrayRef());
204 }
205 
206 void IRNumberingState::number(Block &block) {
207  // Number the arguments of the block.
208  for (BlockArgument arg : block.getArguments()) {
209  valueIDs.try_emplace(arg, nextValueID++);
210  number(arg.getLoc());
211  number(arg.getType());
212  }
213 
214  // Number the operations in this block.
215  unsigned &numOps = blockOperationCounts[&block];
216  for (Operation &op : block) {
217  number(op);
218  ++numOps;
219  }
220 }
221 
222 auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
223  DialectNumbering *&numbering = registeredDialects[dialect];
224  if (!numbering) {
225  numbering = &numberDialect(dialect->getNamespace());
226  numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
227  numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
228  }
229  return *numbering;
230 }
231 
232 auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
233  DialectNumbering *&numbering = dialects[dialect];
234  if (!numbering) {
235  numbering = new (dialectAllocator.Allocate())
236  DialectNumbering(dialect, dialects.size() - 1);
237  }
238  return *numbering;
239 }
240 
241 void IRNumberingState::number(Region &region) {
242  if (region.empty())
243  return;
244  size_t firstValueID = nextValueID;
245 
246  // Number the blocks within this region.
247  size_t blockCount = 0;
248  for (auto it : llvm::enumerate(region)) {
249  blockIDs.try_emplace(&it.value(), it.index());
250  number(it.value());
251  ++blockCount;
252  }
253 
254  // Remember the number of blocks and values in this region.
255  regionBlockValueCounts.try_emplace(&region, blockCount,
256  nextValueID - firstValueID);
257 }
258 
259 void IRNumberingState::number(Operation &op) {
260  // Number the components of an operation that won't be numbered elsewhere
261  // (e.g. we don't number operands, regions, or successors here).
262  number(op.getName());
263  for (OpResult result : op.getResults()) {
264  valueIDs.try_emplace(result, nextValueID++);
265  number(result.getType());
266  }
267 
268  // Only number the operation's dictionary if it isn't empty.
269  DictionaryAttr dictAttr = op.getAttrDictionary();
270  if (!dictAttr.empty())
271  number(dictAttr);
272 
273  number(op.getLoc());
274 }
275 
276 void IRNumberingState::number(OperationName opName) {
277  OpNameNumbering *&numbering = opNames[opName];
278  if (numbering) {
279  ++numbering->refCount;
280  return;
281  }
282  DialectNumbering *dialectNumber = nullptr;
283  if (Dialect *dialect = opName.getDialect())
284  dialectNumber = &numberDialect(dialect);
285  else
286  dialectNumber = &numberDialect(opName.getDialectNamespace());
287 
288  numbering =
289  new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
290  orderedOpNames.push_back(numbering);
291 }
292 
293 void IRNumberingState::number(Type type) {
294  auto it = types.insert({type, nullptr});
295  if (!it.second) {
296  ++it.first->second->refCount;
297  return;
298  }
299  auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
300  it.first->second = numbering;
301  orderedTypes.push_back(numbering);
302 
303  // Check for OpaqueType, which is a dialect-specific type that didn't have a
304  // registered dialect when it got created. We don't want to encode this as the
305  // builtin OpaqueType, we want to encode it as if the dialect was actually
306  // loaded.
307  if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) {
308  numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
309  return;
310  }
311  numbering->dialect = &numberDialect(&type.getDialect());
312 
313  // If this type will be emitted using the bytecode format, perform a dummy
314  // writing to number any nested components.
315  if (const auto *interface = numbering->dialect->interface) {
316  // TODO: We don't allow custom encodings for mutable types right now.
317  if (!type.hasTrait<TypeTrait::IsMutable>()) {
318  NumberingDialectWriter writer(*this);
319  if (succeeded(interface->writeType(type, writer)))
320  return;
321  }
322  }
323  // If this type will be emitted using the fallback, number the nested dialect
324  // resources. We don't number everything (e.g. no nested attributes/types),
325  // because we don't want to encode things we won't decode (the textual format
326  // can't really share much).
327  AsmState tempState(type.getContext());
328  llvm::raw_null_ostream dummyOS;
329  type.print(dummyOS, tempState);
330 
331  // Number the used dialect resources.
332  for (const auto &it : tempState.getDialectResources())
333  number(it.getFirst(), it.getSecond().getArrayRef());
334 }
335 
336 void IRNumberingState::number(Dialect *dialect,
338  DialectNumbering &dialectNumber = numberDialect(dialect);
339  assert(
340  dialectNumber.asmInterface &&
341  "expected dialect owning a resource to implement OpAsmDialectInterface");
342 
343  for (const auto &resource : resources) {
344  // Check if this is a newly seen resource.
345  if (!dialectNumber.resources.insert(resource))
346  return;
347 
348  auto *numbering =
349  new (resourceAllocator.Allocate()) DialectResourceNumbering(
350  dialectNumber.asmInterface->getResourceKey(resource));
351  dialectNumber.resourceMap.insert({numbering->key, numbering});
352  dialectResources.try_emplace(resource, numbering);
353  }
354 }
355 
356 namespace {
357 /// A dummy resource builder used to number dialect resources.
358 struct NumberingResourceBuilder : public AsmResourceBuilder {
359  NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
360  : dialect(dialect), nextResourceID(nextResourceID) {}
361  ~NumberingResourceBuilder() override = default;
362 
363  void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
364  numberEntry(key);
365  }
366  void buildBool(StringRef key, bool) final { numberEntry(key); }
367  void buildString(StringRef key, StringRef) final {
368  // TODO: We could pre-number the value string here as well.
369  numberEntry(key);
370  }
371 
372  /// Number the dialect entry for the given key.
373  void numberEntry(StringRef key) {
374  // TODO: We could pre-number resource key strings here as well.
375 
376  auto it = dialect->resourceMap.find(key);
377  if (it != dialect->resourceMap.end()) {
378  it->second->number = nextResourceID++;
379  it->second->isDeclaration = false;
380  }
381  }
382 
383  DialectNumbering *dialect;
384  unsigned &nextResourceID;
385 };
386 } // namespace
387 
388 void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
389  unsigned nextResourceID = 0;
390  for (DialectNumbering &dialect : getDialects()) {
391  if (!dialect.asmInterface)
392  continue;
393  NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
394  dialect.asmInterface->buildResources(rootOp, dialect.resources,
395  entryBuilder);
396 
397  // Number any resources that weren't added by the dialect. This can happen
398  // if there was no backing data to the resource, but we still want these
399  // resource references to roundtrip, so we number them and indicate that the
400  // data is missing.
401  for (const auto &it : dialect.resourceMap)
402  if (it.second->isDeclaration)
403  it.second->number = nextResourceID++;
404  }
405 }
static void groupByDialectPerByte(T range)
Group and sort the elements of the given range by their parent dialect.
Definition: IRNumbering.cpp:56
This class represents an opaque handle to a dialect resource entry.
Dialect * getDialect() const
Return the dialect that owns the resource.
This class is used to build resource entries for use by the printer.
Definition: AsmState.h:237
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:525
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:71
U dyn_cast() const
Definition: Attributes.h:166
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
bool hasTrait()
Returns true if the type was registered with a particular trait.
Definition: Attributes.h:92
This class represents an argument of a Block.
Definition: Value.h:304
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:76
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual std::string getResourceKey(const AsmDialectResourceHandle &handle) const
Return a key to use for the given resource.
This is a value defined by a result of an operation.
Definition: Value.h:442
This class provides the API for ops that are known to be isolated from above.
Dialect * getDialect() const
Return the dialect this operation is registered to if the dialect is loaded in the context,...
StringRef getDialectNamespace() const
Return the name of the dialect this operation is registered to.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:592
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:540
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:421
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
result_range getResults()
Definition: Operation.h:394
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void print(raw_ostream &os) const
Print the current type.
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:118
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool hasTrait()
Returns true if the type was registered with a particular trait.
Definition: Types.h:184
U dyn_cast() const
Definition: Types.h:311
This class manages numbering IR entities in preparation of bytecode emission.
Definition: IRNumbering.h:134
auto getDialects()
Return the numbered dialects.
Definition: IRNumbering.h:139
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
This header declares functions that assist transformations in the MemRef dialect.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
void writeType(Type type) override
Write a reference to the given type.
Definition: IRNumbering.cpp:27
void writeVarInt(uint64_t) override
Stubbed out methods that are not used for numbering.
Definition: IRNumbering.cpp:33
void writeOwnedString(StringRef) override
Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the ...
Definition: IRNumbering.cpp:37
IRNumberingState & state
The parent numbering state that is populated by this writer.
Definition: IRNumbering.cpp:45
void writeAttribute(Attribute attr) override
Write a reference to the given attribute.
Definition: IRNumbering.cpp:26
void writeResourceHandle(const AsmDialectResourceHandle &resource) override
Write the given handle to a dialect resource.
Definition: IRNumbering.cpp:28
void writeAPIntWithKnownWidth(const APInt &value) override
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
Definition: IRNumbering.cpp:35
void writeSignedVarInt(int64_t value) override
Write a signed variable width integer to the output stream.
Definition: IRNumbering.cpp:34
void writeOwnedBlob(ArrayRef< char > blob) override
Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the en...
Definition: IRNumbering.cpp:42
void writeAPFloatWithKnownSemantics(const APFloat &value) override
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
Definition: IRNumbering.cpp:36
This class represents a numbering entry for a Dialect.
Definition: IRNumbering.h:105
const BytecodeDialectInterface * interface
The bytecode dialect interface of the dialect if defined.
Definition: IRNumbering.h:116
llvm::MapVector< StringRef, DialectResourceNumbering * > resourceMap
A mapping from resource key to the corresponding resource numbering entry.
Definition: IRNumbering.h:125
SetVector< AsmDialectResourceHandle > resources
The referenced resources of this dialect.
Definition: IRNumbering.h:122
const OpAsmDialectInterface * asmInterface
The asm dialect interface of the dialect if defined.
Definition: IRNumbering.h:119
This class represents a numbering entry for a dialect resource.
Definition: IRNumbering.h:86
This class represents the numbering entry of an operation name.
Definition: IRNumbering.h:64
unsigned refCount
The number of references to this name.
Definition: IRNumbering.h:78
DialectNumbering * dialect
The dialect of this value.
Definition: IRNumbering.h:69
This trait is used to determine if a storage user, like Type, is mutable or not.