MLIR  16.0.0git
IRNumbering.cpp
Go to the documentation of this file.
1 //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRNumbering.h"
12 #include "mlir/IR/AsmState.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 
16 using namespace mlir;
17 using namespace mlir::bytecode::detail;
18 
19 //===----------------------------------------------------------------------===//
20 // NumberingDialectWriter
21 //===----------------------------------------------------------------------===//
22 
25 
26  void writeAttribute(Attribute attr) override { state.number(attr); }
27  void writeType(Type type) override { state.number(type); }
28  void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
29  state.number(resource.getDialect(), resource);
30  }
31 
32  /// Stubbed out methods that are not used for numbering.
33  void writeVarInt(uint64_t) override {}
34  void writeSignedVarInt(int64_t value) override {}
35  void writeAPIntWithKnownWidth(const APInt &value) override {}
36  void writeAPFloatWithKnownSemantics(const APFloat &value) override {}
37  void writeOwnedString(StringRef) override {
38  // TODO: It might be nice to prenumber strings and sort by the number of
39  // references. This could potentially be useful for optimizing things like
40  // file locations.
41  }
42  void writeOwnedBlob(ArrayRef<char> blob) override {}
43 
44  /// The parent numbering state that is populated by this writer.
46 };
47 
48 //===----------------------------------------------------------------------===//
49 // IR Numbering
50 //===----------------------------------------------------------------------===//
51 
52 /// Group and sort the elements of the given range by their parent dialect. This
53 /// grouping is applied to sub-sections of the ranged defined by how many bytes
54 /// it takes to encode a varint index to that sub-section.
55 template <typename T>
56 static void groupByDialectPerByte(T range) {
57  if (range.empty())
58  return;
59 
60  // A functor used to sort by a given dialect, with a desired dialect to be
61  // ordered first (to better enable sharing of dialects across byte groups).
62  auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
63  const auto &rhs) {
64  if (lhs->dialect->number == dialectToOrderFirst)
65  return rhs->dialect->number != dialectToOrderFirst;
66  return lhs->dialect->number < rhs->dialect->number;
67  };
68 
69  unsigned dialectToOrderFirst = 0;
70  size_t elementsInByteGroup = 0;
71  auto iterRange = range;
72  for (unsigned i = 1; i < 9; ++i) {
73  // Update the number of elements in the current byte grouping. Reminder
74  // that varint encodes 7-bits per byte, so that's how we compute the
75  // number of elements in each byte grouping.
76  elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
77 
78  // Slice out the sub-set of elements that are in the current byte grouping
79  // to be sorted.
80  auto byteSubRange = iterRange.take_front(elementsInByteGroup);
81  iterRange = iterRange.drop_front(byteSubRange.size());
82 
83  // Sort the sub range for this byte.
84  llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
85  return sortByDialect(dialectToOrderFirst, lhs, rhs);
86  });
87 
88  // Update the dialect to order first to be the dialect at the end of the
89  // current grouping. This seeks to allow larger dialect groupings across
90  // byte boundaries.
91  dialectToOrderFirst = byteSubRange.back()->dialect->number;
92 
93  // If the data range is now empty, we are done.
94  if (iterRange.empty())
95  break;
96  }
97 
98  // Assign the entry numbers based on the sort order.
99  for (auto &entry : llvm::enumerate(range))
100  entry.value()->number = entry.index();
101 }
102 
// NOTE(review): this listing omits the enclosing constructor header and the
// declaration of `numberContext`; presumably
// `IRNumberingState::IRNumberingState(Operation *op) {` and a worklist of
// (Region *, unsigned) pairs (it supports emplace_back/pop_back_val) --
// confirm against the full source.
//
// Numbers all IR rooted at `op`: first the root operation, then every nested
// region through an explicit LIFO worklist, then the dialects, the ordering of
// attributes/op names/types, and finally the dialect resources.
104  // Number the root operation.
105  number(*op);
106 
107  // Push all of the regions of the root operation onto the worklist.
// Each entry pairs a region with the value ID its numbering should start at.
109  for (Region &region : op->getRegions())
110  numberContext.emplace_back(&region, nextValueID);
111 
112  // Iteratively process each of the nested regions.
113  while (!numberContext.empty()) {
114  Region *region;
// Reset `nextValueID` to the start ID recorded when the region was queued.
115  std::tie(region, nextValueID) = numberContext.pop_back_val();
116  number(*region);
117 
118  // Traverse into nested regions.
// These are queued only after the whole enclosing region has been numbered,
// so (for non-isolated ops) their IDs start after all of the parent region's
// values. All regions of one op are queued with the same start ID, so sibling
// regions reuse the same ID range -- presumably safe because values of one
// region are not visible in a sibling region; confirm.
119  for (Operation &op : region->getOps()) {
120  // Isolated regions don't share value numbers with their parent, so we can
121  // start numbering these regions at zero.
122  unsigned opFirstValueID =
123  op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID;
124  for (Region &region : op.getRegions())
125  numberContext.emplace_back(&region, opFirstValueID);
126  }
127  }
128 
129  // Number each of the dialects. For now this is just in the order they were
130  // found, given that the number of dialects on average is small enough to fit
131  // within a single byte (128). If we ever have real world use cases that have
132  // a huge number of dialects, this could be made more intelligent.
// Each dialect's number becomes its index in `dialects`.
133  for (auto &it : llvm::enumerate(dialects))
134  it.value().second->number = it.index();
135 
136  // Number each of the recorded components within each dialect.
137 
138  // First sort by ref count so that the most referenced elements are first. We
139  // try to bias more heavily used elements to the front. This allows for more
140  // frequently referenced things to be encoded using smaller varints.
// The sort is stable, so elements with equal ref counts keep the order in
// which they were first numbered.
141  auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
142  return lhs->refCount > rhs->refCount;
143  };
144  llvm::stable_sort(orderedAttrs, sortByRefCountFn);
145  llvm::stable_sort(orderedOpNames, sortByRefCountFn);
146  llvm::stable_sort(orderedTypes, sortByRefCountFn);
147 
148  // After that, we apply a secondary ordering based on the parent dialect. This
149  // ordering is applied to sub-sections of the element list defined by how many
150  // bytes it takes to encode a varint index to that sub-section. This allows
151  // for more efficiently encoding components of the same dialect (e.g. we only
152  // have to encode the dialect reference once).
// groupByDialectPerByte also writes each element's final `number`.
153  groupByDialectPerByte(llvm::makeMutableArrayRef(orderedAttrs));
154  groupByDialectPerByte(llvm::makeMutableArrayRef(orderedOpNames));
155  groupByDialectPerByte(llvm::makeMutableArrayRef(orderedTypes));
156 
157  // Finalize the numbering of the dialect resources.
158  finalizeDialectResourceNumberings(op);
159 }
160 
161 void IRNumberingState::number(Attribute attr) {
162  auto it = attrs.insert({attr, nullptr});
163  if (!it.second) {
164  ++it.first->second->refCount;
165  return;
166  }
167  auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
168  it.first->second = numbering;
169  orderedAttrs.push_back(numbering);
170 
171  // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
172  // have a registered dialect when it got created. We don't want to encode this
173  // as the builtin OpaqueAttr, we want to encode it as if the dialect was
174  // actually loaded.
175  if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
176  numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
177  return;
178  }
179  numbering->dialect = &numberDialect(&attr.getDialect());
180 
181  // If this attribute will be emitted using the bytecode format, perform a
182  // dummy writing to number any nested components.
183  if (const auto *interface = numbering->dialect->interface) {
184  // TODO: We don't allow custom encodings for mutable attributes right now.
185  if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
186  NumberingDialectWriter writer(*this);
187  if (succeeded(interface->writeAttribute(attr, writer)))
188  return;
189  }
190  }
191  // If this attribute will be emitted using the fallback, number the nested
192  // dialect resources. We don't number everything (e.g. no nested
193  // attributes/types), because we don't want to encode things we won't decode
194  // (the textual format can't really share much).
195  AsmState tempState(attr.getContext());
196  llvm::raw_null_ostream dummyOS;
197  attr.print(dummyOS, tempState);
198 
199  // Number the used dialect resources.
200  for (const auto &it : tempState.getDialectResources())
201  number(it.getFirst(), it.getSecond().getArrayRef());
202 }
203 
204 void IRNumberingState::number(Block &block) {
205  // Number the arguments of the block.
206  for (BlockArgument arg : block.getArguments()) {
207  valueIDs.try_emplace(arg, nextValueID++);
208  number(arg.getLoc());
209  number(arg.getType());
210  }
211 
212  // Number the operations in this block.
213  unsigned &numOps = blockOperationCounts[&block];
214  for (Operation &op : block) {
215  number(op);
216  ++numOps;
217  }
218 }
219 
220 auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
221  DialectNumbering *&numbering = registeredDialects[dialect];
222  if (!numbering) {
223  numbering = &numberDialect(dialect->getNamespace());
224  numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
225  numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
226  }
227  return *numbering;
228 }
229 
230 auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
231  DialectNumbering *&numbering = dialects[dialect];
232  if (!numbering) {
233  numbering = new (dialectAllocator.Allocate())
234  DialectNumbering(dialect, dialects.size() - 1);
235  }
236  return *numbering;
237 }
238 
239 void IRNumberingState::number(Region &region) {
240  if (region.empty())
241  return;
242  size_t firstValueID = nextValueID;
243 
244  // Number the blocks within this region.
245  size_t blockCount = 0;
246  for (auto &it : llvm::enumerate(region)) {
247  blockIDs.try_emplace(&it.value(), it.index());
248  number(it.value());
249  ++blockCount;
250  }
251 
252  // Remember the number of blocks and values in this region.
253  regionBlockValueCounts.try_emplace(&region, blockCount,
254  nextValueID - firstValueID);
255 }
256 
257 void IRNumberingState::number(Operation &op) {
258  // Number the components of an operation that won't be numbered elsewhere
259  // (e.g. we don't number operands, regions, or successors here).
260  number(op.getName());
261  for (OpResult result : op.getResults()) {
262  valueIDs.try_emplace(result, nextValueID++);
263  number(result.getType());
264  }
265 
266  // Only number the operation's dictionary if it isn't empty.
267  DictionaryAttr dictAttr = op.getAttrDictionary();
268  if (!dictAttr.empty())
269  number(dictAttr);
270 
271  number(op.getLoc());
272 }
273 
274 void IRNumberingState::number(OperationName opName) {
275  OpNameNumbering *&numbering = opNames[opName];
276  if (numbering) {
277  ++numbering->refCount;
278  return;
279  }
280  DialectNumbering *dialectNumber = nullptr;
281  if (Dialect *dialect = opName.getDialect())
282  dialectNumber = &numberDialect(dialect);
283  else
284  dialectNumber = &numberDialect(opName.getDialectNamespace());
285 
286  numbering =
287  new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
288  orderedOpNames.push_back(numbering);
289 }
290 
291 void IRNumberingState::number(Type type) {
292  auto it = types.insert({type, nullptr});
293  if (!it.second) {
294  ++it.first->second->refCount;
295  return;
296  }
297  auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
298  it.first->second = numbering;
299  orderedTypes.push_back(numbering);
300 
301  // Check for OpaqueType, which is a dialect-specific type that didn't have a
302  // registered dialect when it got created. We don't want to encode this as the
303  // builtin OpaqueType, we want to encode it as if the dialect was actually
304  // loaded.
305  if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) {
306  numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
307  return;
308  }
309  numbering->dialect = &numberDialect(&type.getDialect());
310 
311  // If this type will be emitted using the bytecode format, perform a dummy
312  // writing to number any nested components.
313  if (const auto *interface = numbering->dialect->interface) {
314  // TODO: We don't allow custom encodings for mutable types right now.
315  if (!type.hasTrait<TypeTrait::IsMutable>()) {
316  NumberingDialectWriter writer(*this);
317  if (succeeded(interface->writeType(type, writer)))
318  return;
319  }
320  }
321  // If this type will be emitted using the fallback, number the nested dialect
322  // resources. We don't number everything (e.g. no nested attributes/types),
323  // because we don't want to encode things we won't decode (the textual format
324  // can't really share much).
325  AsmState tempState(type.getContext());
326  llvm::raw_null_ostream dummyOS;
327  type.print(dummyOS, tempState);
328 
329  // Number the used dialect resources.
330  for (const auto &it : tempState.getDialectResources())
331  number(it.getFirst(), it.getSecond().getArrayRef());
332 }
333 
334 void IRNumberingState::number(Dialect *dialect,
336  DialectNumbering &dialectNumber = numberDialect(dialect);
337  assert(
338  dialectNumber.asmInterface &&
339  "expected dialect owning a resource to implement OpAsmDialectInterface");
340 
341  for (const auto &resource : resources) {
342  // Check if this is a newly seen resource.
343  if (!dialectNumber.resources.insert(resource))
344  return;
345 
346  auto *numbering =
347  new (resourceAllocator.Allocate()) DialectResourceNumbering(
348  dialectNumber.asmInterface->getResourceKey(resource));
349  dialectNumber.resourceMap.insert({numbering->key, numbering});
350  dialectResources.try_emplace(resource, numbering);
351  }
352 }
353 
354 namespace {
355 /// A dummy resource builder used to number dialect resources.
356 struct NumberingResourceBuilder : public AsmResourceBuilder {
357  NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
358  : dialect(dialect), nextResourceID(nextResourceID) {}
359  ~NumberingResourceBuilder() override = default;
360 
361  void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
362  numberEntry(key);
363  }
364  void buildBool(StringRef key, bool) final { numberEntry(key); }
365  void buildString(StringRef key, StringRef) final {
366  // TODO: We could pre-number the value string here as well.
367  numberEntry(key);
368  }
369 
370  /// Number the dialect entry for the given key.
371  void numberEntry(StringRef key) {
372  // TODO: We could pre-number resource key strings here as well.
373 
374  auto it = dialect->resourceMap.find(key);
375  if (it != dialect->resourceMap.end()) {
376  it->second->number = nextResourceID++;
377  it->second->isDeclaration = false;
378  }
379  }
380 
381  DialectNumbering *dialect;
382  unsigned &nextResourceID;
383 };
384 } // namespace
385 
386 void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
387  unsigned nextResourceID = 0;
388  for (DialectNumbering &dialect : getDialects()) {
389  if (!dialect.asmInterface)
390  continue;
391  NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
392  dialect.asmInterface->buildResources(rootOp, dialect.resources,
393  entryBuilder);
394 
395  // Number any resources that weren't added by the dialect. This can happen
396  // if there was no backing data to the resource, but we still want these
397  // resource references to roundtrip, so we number them and indicate that the
398  // data is missing.
399  for (const auto &it : dialect.resourceMap)
400  if (it.second->isDeclaration)
401  it.second->number = nextResourceID++;
402  }
403 }
static void groupByDialectPerByte(T range)
Group and sort the elements of the given range by their parent dialect.
Definition: IRNumbering.cpp:56
static constexpr const bool value
This class represents an opaque handle to a dialect resource entry.
Dialect * getDialect() const
Return the dialect that owns the resource.
This class is used to build resource entries for use by the printer.
Definition: AsmState.h:236
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:524
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:74
U dyn_cast() const
Definition: Attributes.h:127
void print(raw_ostream &os, bool elideType=false) const
Print the attribute.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:20
bool hasTrait()
Returns true if the attribute was registered with a particular trait.
Definition: Attributes.h:95
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:76
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual std::string getResourceKey(const AsmDialectResourceHandle &handle) const
Return a key to use for the given resource.
This is a value defined by a result of an operation.
Definition: Value.h:442
This class provides the API for ops that are known to be isolated from above.
Dialect * getDialect() const
Return the dialect this operation is registered to if the dialect is loaded in the context,...
StringRef getDialectNamespace() const
Return the name of the dialect this operation is registered to.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.h:359
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
result_range getResults()
Definition: Operation.h:332
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void print(raw_ostream &os) const
Print the current type.
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:121
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:19
bool hasTrait()
Returns true if the type was registered with a particular trait.
Definition: Types.h:184
U dyn_cast() const
Definition: Types.h:270
This class manages numbering IR entities in preparation of bytecode emission.
Definition: IRNumbering.h:134
auto getDialects()
Return the numbered dialects.
Definition: IRNumbering.h:139
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Include the generated interface declarations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
void writeType(Type type) override
Write a reference to the given type.
Definition: IRNumbering.cpp:27
void writeVarInt(uint64_t) override
Stubbed out methods that are not used for numbering.
Definition: IRNumbering.cpp:33
void writeOwnedString(StringRef) override
Write a string to the bytecode, which is owned by the caller and is guaranteed to not die before the ...
Definition: IRNumbering.cpp:37
IRNumberingState & state
The parent numbering state that is populated by this writer.
Definition: IRNumbering.cpp:45
void writeAttribute(Attribute attr) override
Write a reference to the given attribute.
Definition: IRNumbering.cpp:26
void writeResourceHandle(const AsmDialectResourceHandle &resource) override
Write the given handle to a dialect resource.
Definition: IRNumbering.cpp:28
void writeAPIntWithKnownWidth(const APInt &value) override
Write an APInt to the bytecode stream whose bitwidth will be known externally at read time.
Definition: IRNumbering.cpp:35
void writeSignedVarInt(int64_t value) override
Write a signed variable width integer to the output stream.
Definition: IRNumbering.cpp:34
void writeOwnedBlob(ArrayRef< char > blob) override
Write a blob to the bytecode, which is owned by the caller and is guaranteed to not die before the en...
Definition: IRNumbering.cpp:42
void writeAPFloatWithKnownSemantics(const APFloat &value) override
Write an APFloat to the bytecode stream whose semantics will be known externally at read time.
Definition: IRNumbering.cpp:36
This class represents a numbering entry for a Dialect.
Definition: IRNumbering.h:105
const BytecodeDialectInterface * interface
The bytecode dialect interface of the dialect if defined.
Definition: IRNumbering.h:116
llvm::MapVector< StringRef, DialectResourceNumbering * > resourceMap
A mapping from resource key to the corresponding resource numbering entry.
Definition: IRNumbering.h:125
SetVector< AsmDialectResourceHandle > resources
The referenced resources of this dialect.
Definition: IRNumbering.h:122
const OpAsmDialectInterface * asmInterface
The asm dialect interface of the dialect if defined.
Definition: IRNumbering.h:119
This class represents a numbering entry for a dialect resource.
Definition: IRNumbering.h:86
This class represents the numbering entry of an operation name.
Definition: IRNumbering.h:64
unsigned refCount
The number of references to this name.
Definition: IRNumbering.h:78
DialectNumbering * dialect
The dialect of this value.
Definition: IRNumbering.h:69
This trait is used to determine if a storage user, like Type, is mutable or not.