MLIR  19.0.0git
Merger.h
Go to the documentation of this file.
1 //===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for dealing with iteration lattices.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
15 
19 #include "mlir/IR/Value.h"
20 #include "llvm/ADT/BitVector.h"
21 
22 #include <optional>
23 
24 namespace mlir {
25 namespace sparse_tensor {
26 
namespace detail {
/// A constant serving as the canonically invalid identifier, regardless of
/// the identifier type (every identifier type in this header is an alias of
/// `unsigned`, so a single all-ones value works for all of them).
///
/// Declared `inline constexpr` (C++17) rather than `static constexpr` so that
/// all translation units including this header share one definition instead
/// of each receiving its own internal-linkage copy.
inline constexpr unsigned kInvalidId = -1u;
} // namespace detail
32 
33 /// Tensor identifiers, chosen to be the `BlockArgument::getArgNumber`
34 /// of the value passed to `Merger::buildTensorExp`.
35 using TensorId = unsigned;
36 
37 /// Loop identifiers.
38 using LoopId = unsigned;
39 
40 /// A compressed representation of `std::pair<TensorId, LoopId>`.
41 /// The compression scheme is such that this also serves as an index
42 /// into the bitvector stored in `LatPoint` (since that bitvector is
43 /// just the implementation for a set of `TensorLoopId` values).
44 using TensorLoopId = unsigned;
45 
46 /// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
47 /// and serve as unique identifiers for the corresponding `TensorExp` object.
48 using ExprId = unsigned;
49 
50 /// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
51 /// and serve as unique identifiers for the corresponding `LatPoint` object.
52 using LatPointId = unsigned;
53 
54 /// `LatSet` identifiers. These are allocated by `Merger::addSet` (and
55 /// by other methods calling that one), and serve as unique identifiers
56 /// for the corresponding `SmallVector<LatPointId>` object.
57 using LatSetId = unsigned;
58 
59 /// A pair of level and its corresponding LevelType of a tensor.
60 using LvlLTPair = std::pair<Level, LevelType>;
61 
62 /// A pair of loop id and its coefficients. E.g., for affine expression in the
63 /// affine map `2 * d0`, loop id = 0, coefficient = 2.
64 using LoopCoeffPair = std::pair<LoopId, unsigned>;
65 
66 /// Tensor expression. Represents an MLIR expression in tensor index notation.
67 struct TensorExp final {
68  enum class Kind;
69 
70  /// Child subexpressions for non-leaf expressions.
71  struct Children final {
74  };
75 
76  /// The `x` parameter has different types depending on the value of the
77  /// `k` parameter. The correspondences are:
78  /// * `kTensor` -> `TensorId`
79  /// * `kInvariant` -> `kInvalidId`
80  /// * `kLoopVar` -> `LoopId`
81  /// * else -> `ExprId`
82  ///
83  /// The `y`, `v`, and `op` parameters either must or must not be
84  /// `kInvalidId`/`nullptr`, depending on the value of the `k` parameter;
85  /// however, they have uniform C++ types regardless of the value of `k`.
86  TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a);
87 
88  /// Tensor expression kind.
90 
91  union {
92  /// `kTensor` expressions simply have a tensor identifier.
94 
95  /// `kLoopVar` expressions simply have a loop identifier.
97 
98  /// All other expressions hold the `ExprId`s of their children.
100  };
101 
102  /// Direct link to IR for an invariant or the destination value (to
103  /// infer destination type) of a cast operation. During code generation,
104  /// this field may be used to cache "hoisted" loop invariant tensor loads.
106 
107  /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
108  /// and kSelect, this holds the original operation with all regions. For
109  /// kBinaryBranch, this holds the YieldOp for the left or right half
110  /// to be merged into a nested scf loop.
111  ///
112  /// For kDenseOp, this holds the actual operation that we cannot sparsify
113  /// but whose operands are all dense.
115 
116  /// An optional attribute that is required to determine the semantics of the
117  /// operations. E.g., CmpPredicateAttr for CmpI/CmpF operations.
119 };
120 
121 /// Tensor expression kind.
122 ///
123 /// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
124 /// That is, its argument is a `LoopId` identifying the loop-variable
125 /// in question, and its value will be the current iteration's value.
126 /// The `kSynZero` leaf kind is for representing a synthetic zero value,
127 /// which can be introduced when sparsifying operations like `arith::cmp`
128 /// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
///
/// The enumerators are grouped by the section comments below: leaf kinds
/// (`kTensor` .. `kLoopVar`), unary kinds (`kAbsF` .. `kSelect`), and binary
/// kinds (`kMulF` .. `kDenseOp`). Only `kTensor` is given an explicit value
/// (0); every later enumerator takes the next consecutive value.
/// NOTE(review): code outside this header may depend on this grouping and
/// order (e.g., range comparisons on the kind) -- confirm before reordering
/// or inserting enumerators.
129 enum class TensorExp::Kind {
130  // Leaf.
131  kTensor = 0,
132  kSynZero,
133  kInvariant,
134  kLoopVar,
135  // Unary operations.
136  kAbsF,
137  kAbsC,
138  kAbsI,
139  kCeilF,
140  kFloorF,
141  kSqrtF,
142  kSqrtC,
143  kExpm1F,
144  kExpm1C,
145  kLog1pF,
146  kLog1pC,
147  kSinF,
148  kSinC,
149  kTanhF,
150  kTanhC,
151  kNegF,
152  kNegC,
153  kNegI,
154  kTruncF,
155  kExtF,
156  kCastFS, // signed
157  kCastFU, // unsigned
158  kCastSF, // signed
159  kCastUF, // unsigned
160  kCastS, // signed
161  kCastU, // unsigned
162  kCastIdx,
163  kTruncI,
164  kCIm, // complex.im
165  kCRe, // complex.re
166  kBitCast,
167  kBinaryBranch, // semiring unary branch created from a binary op
168  kUnary, // semiring unary op
169  kSelect, // custom selection criteria
170  // Binary operations.
171  kMulF,
172  kMulC,
173  kMulI,
174  kDivF,
175  kDivC, // complex
176  kDivS, // signed
177  kDivU, // unsigned
178  kAddF,
179  kAddC,
180  kAddI,
181  kSubF,
182  kSubC,
183  kSubI,
184  kAndI,
185  kOrI,
186  kXorI,
187  kCmpI,
188  kCmpF,
189  kShrS, // signed
190  kShrU, // unsigned
191  kShlI,
192  kBinary, // semiring binary op
193  kReduce, // semiring reduction op
194  kDenseOp, // special category of operations requiring all dense operands
195 };
196 
197 /// Lattice point. Each lattice point consists of a formal conjunction
198 /// of `TensorLoopId`s, together with the identifier of the corresponding
199 /// tensor expression. The formal conjunction is represented as a set of
200 /// `TensorLoopId`, where that set is implemented as a `BitVector`.
201 struct LatPoint final {
202  /// Construct a lattice point with the empty set of `TensorLoopId`s.
203  LatPoint(unsigned size, ExprId e) : bits(size, false), exp(e) {}
204 
205  /// Construct a lattice point from the given set of `TensorLoopId`s.
206  LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}
207 
208  /// Conjunction of all `TensorLoopId`s involved in the tensor expression.
209  BitVector bits;
210 
211  /// Simplified conjunction of `TensorLoopId` as bitvector. This
212  /// represents a simplified condition under which this tensor expression
213  /// must execute. Pre-computed during codegen to avoid repeated eval.
214  BitVector simple;
215 
216  /// Identifier of the tensor expression.
218 };
219 
220 /// A class to handle all iteration lattice operations. This class abstracts
221 /// away from some implementation details of storing iteration lattices and
222 /// tensor expressions. This allows for fine-tuning performance characteristics
223 /// independently from the basic algorithm if bottlenecks are identified.
224 class Merger {
225 public:
226  /// Constructs a merger for the given number of tensors and loops. The user
227  /// supplies the number of tensors involved in the kernel, with the last
228  /// tensor in this set denoting the output tensor. The merger adds an
229  /// additional synthetic tensor at the end of this set to represent all
230  /// invariant expressions in the kernel.
231  ///
232  /// The maxLvlRank specifies the max level rank of all inputs/output tensors.
233  /// It is used to pre-allocate sufficient memory for internal storage.
234  Merger(unsigned numInputOutputTensors, unsigned numLoops,
235  unsigned maxLvlRank);
236 
237  //
238  // Constructing valid tensor and loop identifiers.
239  //
240 
241  /// Safely converts the argument to a tensor identifier.
242  constexpr TensorId makeTensorId(unsigned t) const {
243  assert(isValidTensorId(t));
244  return t;
245  }
246 
247  /// Safely converts the argument to a loop identifier.
248  constexpr LoopId makeLoopId(unsigned i) const {
249  assert(isValidLoopId(i));
250  return i;
251  }
252 
253  /// Safely converts the arguments to a pair of (tensor,loop) identifiers.
254  constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
255  assert(isValidTensorId(t) && isValidLoopId(i));
256  return numTensors * i + t;
257  }
258 
259  //
260  // Allocating new expressions, points, and sets.
261  //
262 
263  /// Constructs a new tensor expression, and returns its identifier.
265  /// Constructs a new loop-variable expression, and returns its identifier.
267  /// Constructs a new invariant expression, and returns its identifier.
269  /// Constructs a new synthetic zero expression.
271  /// Constructs a new unary or binary expression, and returns its identifier.
273  Operation *op = nullptr, Attribute attr = nullptr);
274  /// Constructs a new sesquinary expression, and returns its identifier.
275  /// Currently no sesquinary `Kind` allows specifying the `op`, but we
276  /// allow it anyways because `mapSet` is designed to allow it.
277  ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr,
278  Attribute attr = nullptr);
279 
280  /// Constructs a new iteration lattice point, and returns its identifier.
282  LatPointId addLat(const BitVector &bits, ExprId e);
283 
284  /// Constructs a new (initially empty) set, and returns its identifier.
285  LatSetId addSet();
286 
287  /// Computes a single conjunction of two lattice points by taking the "union"
288  /// of `LoopId` (effectively constructing a larger "intersection" of those
289  /// loops) with a newly constructed tensor (sub)expression of given kind.
290  /// Returns the identifier of the new lattice point.
292  Operation *op = nullptr);
293 
294  /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`.
295  /// Returns the identifier of the new set.
296  LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);
297 
298  /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`.
299  /// Returns the identifier of the new set.
300  LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);
301 
302  /// Disjunctive merge of two lattice sets that also sets one of the operands to
303  /// zero: `(s0 /\_op s1 (e0 op e1), s0 (0 op e0), s1 (e1 op 0))`.
304  /// Returns the identifier of the new set.
306 
307  /// Disjunctive merge of two lattice sets with custom handling of the
308  /// overlap, left, and right regions. Any region may be left missing
309  /// in the output. Returns the identifier of the new set.
311  bool includeLeft, TensorExp::Kind ltrans, Operation *opleft,
312  bool includeRight, TensorExp::Kind rtrans,
313  Operation *opright);
314 
315  /// Maps the unary operator over the lattice set of the operand, i.e. each
316  /// lattice point on an expression E is simply copied over, but with OP E
317  /// as new expression. Returns the identifier of the new set.
319  Operation *op = nullptr);
320 
321  /// Maps the binary operator to the same operation but with one of its operands
322  /// set to zero, i.e. each lattice point on an expression E is simply copied
323  /// over, but with `OP 0 E` (if lhsZero == true) or `OP E 0` (if lhsZero ==
324  /// false) as new expression. Returns the identifier of the new set.
325  LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero);
326 
327  /// Optimizes the iteration lattice points in the given set. This
328  /// method should be called right before code generation to avoid
329  /// generating redundant loops and conditions.
331 
332  /// Simplifies the conditions in a conjunction of a given lattice point
333  /// within the given set using just two basic rules:
334  /// (1) multiple dense conditions are reduced to single dense, and
335  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
336  BitVector simplifyCond(LatSetId s, LatPointId p);
337 
338  /// Returns true if p0 > p1.
339  bool latGT(LatPointId p0, LatPointId p1) const;
340 
341  /// Returns true if p0 and p1 only differ in dense.
342  bool onlyDenseDiff(LatPointId p0, LatPointId p1) const;
343 
344  /// Gets the tensor-identifier of the `TensorLoopId`.
345  constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; }
346  /// Gets the loop-identifier of the `TensorLoopId`.
347  constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; }
348 
349  /// Gets the total number of tensors (including the output-tensor and
350  /// synthetic-tensor).
351  constexpr unsigned getNumTensors() const { return numTensors; }
352 
353  /// Gets the total number of loops (native loops + filter loops).
354  constexpr unsigned getNumLoops() const { return numLoops; }
355 
356  /// Returns true if `b` is the `i`th loop of the output tensor.
357  constexpr bool isOutTensor(TensorLoopId b, LoopId i) const {
358  return b == makeTensorLoopId(outTensor, i);
359  }
360 
361  /// Gets the output tensor's identifier.
362  constexpr TensorId getOutTensorID() const { return outTensor; }
363 
364  /// Gets the synthetic tensor's identifier (used for all invariant
365  /// tensor expressions).
366  constexpr TensorId getSynTensorID() const { return syntheticTensor; }
367 
368  /// Returns true if the expression is `(kTensor t)`.
369  bool expIsTensor(ExprId e, TensorId t) const {
370  const auto &expr = exp(e);
371  return expr.kind == TensorExp::Kind::kTensor && expr.tensor == t;
372  }
373 
374  /// Returns true if the expression contains the tensor as an operand.
375  bool expContainsTensor(ExprId e, TensorId t) const;
376 
377  /// Returns true if the expression contains a negation on output tensor.
378  /// I.e., `- outTensor` or `exp - outputTensor`
379  /// NOTE: this is a trivial test in that it does not handle recursive
380  /// negation, i.e., it returns true when the expression is `-(-tensor)`.
381  bool hasNegateOnOut(ExprId e) const;
382 
383  /// Returns true if given tensor iterates *only* in the given tensor
384  /// expression. For the output tensor, this defines a "simply dynamic"
385  /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
386  /// sparse vector a.
387  bool isSingleCondition(TensorId t, ExprId e) const;
388 
389  /// Returns true if any `TensorLoopId` in the bitvector corresponds
390  /// to sparse level-type.
391  bool hasAnySparse(const BitVector &bits) const;
392 
393  /// Returns true if bits contains a dependent index reduction condition on
394  /// sparse levels.
395  bool hasSparseIdxReduction(const BitVector &bits) const;
396 
397  /// Gets the level-type of the `t`th tensor on `i`th loop.
399  assert(isValidTensorId(t) && isValidLoopId(i));
400  return lvlTypes[t][i];
401  }
402 
403  /// Gets the level-type of the TensorLoopId.
405  return getLvlType(tensor(b), loop(b));
406  }
407 
408  /// Gets the loop identifier for the `lvl`th level of the `t`th tensor.
409  std::optional<LoopId> getLoopId(TensorId t, Level lvl) const {
410  assert(isValidLevel(t, lvl));
411  return lvlToLoop[t][lvl];
412  }
413 
414  /// Gets the level number of the the `t`th tensor on `i`th loop.
415  std::optional<Level> getLvl(TensorId t, LoopId i) const {
416  assert(isValidTensorId(t) && isValidLoopId(i));
417  return loopToLvl[t][i];
418  }
419  std::optional<Level> getLvl(TensorLoopId b) const {
420  return getLvl(tensor(b), loop(b));
421  }
422 
423  /// Sets the level number and level-type of the `t`th tensor on
424  /// `i`th loop.
426  assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt));
427  lvlTypes[t][i] = lt;
428  loopToLvl[t][i] = lvl;
429  lvlToLoop[t][lvl] = i;
430  // TODO: favor a constant loop bound when there are multiple choices.
431  loopBounds[i] = std::make_pair(t, lvl);
432  }
433 
435  TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>;
436 
437  /// Iterates over a set of `TensorLoopId`s, invoking the callback
438  /// for each `TensorLoopId` and passing it the corresponding tensor
439  /// identifier, level, and level-type, following with a boolean value
440  /// indicating whether it is a dependent index reduction loop condition.
442  ForeachTensorLoopIdCallback callback) const {
443  // TODO: the default ought to be simple=true; but we'll need to make
444  // sure to update all the tests to make sure they do the right thing.
445  foreachTensorLoopId(p, /*simple=*/false, callback);
446  }
447  void foreachTensorLoopId(LatPointId p, bool simple,
448  ForeachTensorLoopIdCallback callback) const {
449  const auto &point = lat(p);
450  const auto &bits = simple ? point.simple : point.bits;
451  for (const TensorLoopId b : bits.set_bits()) {
452  const TensorId t = tensor(b);
453  const auto optLvl = getLvl(b);
454  const auto lvlTp = getLvlType(b);
455  if (isLvlWithNonTrivialIdxExp(b)) {
456  // This must be an undefined level.
457  assert(!optLvl.has_value());
458  // Slice the tid along the dependent level to iterate current loop.
459  callback(b, t, getLoopDependentLevel(b), lvlTp,
460  /*isIdxReduc=*/true);
461  } else {
462  callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false);
463  }
464  }
465  }
466 
467  /// Sets whether the output tensor is sparse or not.
468  void setHasSparseOut(bool s) { hasSparseOut = s; }
469 
470  /// Establishes the two-way map that i <-> <t, lvl, lt>.
472  LevelType lt, unsigned coefficient) {
473  assert(isValidLoopId(i) && isValidLevel(t, lvl));
474  assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
475  loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);
476  levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
477  }
478 
479  /// Whether the loop has a dependent slice.
481  assert(isValidTensorId(t) && isValidLoopId(i));
482  return loopToUnresolvedLvls[i][t].has_value();
483  }
484 
485  /// Returns the list of loop indices which appear in the non-trivial index
486  /// expression on t_l, e.g., A[i+j] => {i, j}
487  std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
488  assert(isValidLevel(t, lvl));
489  return levelToDependentLoop[t][lvl];
490  }
491 
492  /// Returns the defining [tid, lvl] for the loop.
493  std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const {
494  assert(isValidLoopId(i));
495  return loopBounds[i];
496  }
497 
498  /// Checks whether the TensorLoopId represents a tensor level that contains
499  /// a non-trivial index expression.
501  const TensorId t = tensor(b);
502  const LoopId i = loop(b);
503  assert(isValidTensorId(t) && isValidLoopId(i));
504  return loopToUnresolvedLvls[i][t].has_value();
505  }
506 
507  /// Checks whether the TensorLoopId represents a sparse tensor level that
508  /// contains a non-trivial index expression.
510  if (isLvlWithNonTrivialIdxExp(b)) {
511  auto lt = getLoopDependentLevelType(b);
512  return lt.hasSparseSemantic();
513  }
514  return false;
515  }
516 
518  assert(isLvlWithNonTrivialIdxExp(b));
519  return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
520  }
521 
523  assert(isLvlWithNonTrivialIdxExp(b));
524  return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
525  }
526 
527  /// Convenience getters to immediately access the stored nodes.
528  /// These methods return `const&` because the underlying objects must
529  /// not be mutated by client code. The only exception is for mutating
530  /// the value associated with an expression, for which there are
531  /// dedicated methods below.
532  ///
533  /// NOTE: It is inadvisable to keep the reference alive for a long
534  /// time (e.g., as in `TensorExpr &te = merger.exp(e)`), since insertions
535  /// into the merger can cause data movement which will invalidate the
536  /// underlying memory address. This isn't just a problem with the `&`
537  /// references, but also applies to the `ArrayRef`. In particular,
538  /// using `for (LatPointId p : merger.set(s))` will run into the same
539  /// dangling-reference problems if the loop body inserts new sets.
540  const TensorExp &exp(ExprId e) const {
541  assert(isValidExprId(e));
542  return tensorExps[e];
543  }
544  const LatPoint &lat(LatPointId p) const {
545  assert(isValidLatPointId(p));
546  return latPoints[p];
547  }
549  assert(isValidLatSetId(s));
550  return latSets[s];
551  }
552 
553  /// Checks whether the given expression has an associated value.
554  bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
555 
556  /// Sets the expression to have the associated value. Asserts that the new
557  /// value is defined, and that the expression does not already have a value.
558  void setExprValue(ExprId e, Value v) {
559  assert(!exp(e).val && "Expression already has an associated value");
560  assert(v && "Trying to assign an undefined value");
561  tensorExps[e].val = v;
562  }
563 
564  /// Clears the value associated with the expression. Asserts that the
565  /// expression does indeed have an associated value before clearing it.
567  assert(exp(e).val && "Expression does not have an associated value");
568  tensorExps[e].val = Value();
569  }
570 
571 #ifndef NDEBUG
572  /// Print methods (for debugging).
573  void dumpExp(ExprId e) const;
574  void dumpLat(LatPointId p) const;
575  void dumpSet(LatSetId s) const;
576  void dumpBits(const BitVector &bits) const;
577 #endif
578 
579  /// Builds the iteration lattices in a bottom-up traversal given the
580  /// remaining tensor (sub)expression and the next loop in the iteration
581  /// graph. Returns the identifier of the root set.
583 
584  /// Builds a tensor expression from the given Linalg operation.
585  /// On success, returns the identifier of the root expression.
586  std::optional<ExprId> buildTensorExpFromLinalg(linalg::GenericOp op);
587 
588  /// Rebuilds SSA format from a tensor expression.
589  Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
590  Value v1) const;
591 
592 private:
593  /// Private helpers.
594  constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; }
595  constexpr bool isValidLoopId(LoopId i) const {
596  return i != detail::kInvalidId && i < numLoops;
597  }
598  bool isValidLevel(TensorId t, Level lvl) const {
599  assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());
600  return isValidTensorId(t) && lvl < lvlToLoop[t].size();
601  }
602  bool isValidExprId(ExprId e) const {
603  return e != detail::kInvalidId && e < tensorExps.size();
604  }
605  bool isValidLatPointId(LatPointId p) const {
606  return p != detail::kInvalidId && p < latPoints.size();
607  }
608  bool isValidLatSetId(LatSetId s) const {
609  return s != detail::kInvalidId && s < latSets.size();
610  }
611  bool maybeZero(ExprId e) const;
612  bool isInvariant(ExprId e) const {
613  return exp(e).kind == TensorExp::Kind::kInvariant;
614  }
615  Type inferType(ExprId e, Value src) const;
616 
617  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
618  /// The boolean value returned indicates whether the result of the current
619  /// operation being built depends on any value that is loaded from a sparse
620  /// tensor.
621  std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
622  Value v);
623 
624  /// Merger data structures.
625  const TensorId outTensor;
626  const TensorId syntheticTensor;
627  const unsigned numTensors;
628  const unsigned numLoops;
629  bool hasSparseOut;
630 
631  // Below we use `std::vector` for things which have a priori fixed
632  // sizes, whereas we use `llvm::SmallVector` for things with variable
633  // size. Do beware that these two classes differ in the semantics of
634  // `operator[]`: `SmallVector` performs OOB checks (via assertions), whereas
635  // `std::vector` does not.
636 
637  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
638  std::vector<std::vector<LevelType>> lvlTypes;
639 
640  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
641  std::vector<std::vector<std::optional<Level>>> loopToLvl;
642 
643  /// Map that converts pair<TensorId, Level> to the corresponding LoopId.
644  std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
645 
646  /// Map from a loop to its dependencies if any.
647  /// The dependencies of a loop are a set of (tensor, level) pairs.
648  /// It is currently only set for non-trivial index expressions.
649  /// E.g., A[i+j] => i and j will have dependencies {A0, lt(A0)} to indicate
650  /// that i and j are used in the non-trivial index expression on A0.
651  std::vector<std::vector<std::optional<LvlLTPair>>> loopToUnresolvedLvls;
652 
653  /// The inverse map of `loopToUnresolvedLvls`: tensor level -> dependent loops.
654  /// E.g., for A[2i+j], we have A0 => {(i, 2), (j, 1)}, to indicate that A0
655  /// uses both {i, j} to compute its indices and that the coefficients on
656  /// those loop ids are 2 and 1 respectively.
657  std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
658 
659  /// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
660  std::vector<std::pair<TensorId, Level>> loopBounds;
661 
662  llvm::SmallVector<TensorExp> tensorExps;
663  llvm::SmallVector<LatPoint> latPoints;
665 };
666 
667 } // namespace sparse_tensor
668 } // namespace mlir
669 
670 #endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A class to handle all iteration lattice operations.
Definition: Merger.h:224
void setHasSparseOut(bool s)
Sets whether the output tensor is sparse or not.
Definition: Merger.h:468
constexpr unsigned getNumLoops() const
Gets the total number of loops (native loops + filter loops).
Definition: Merger.h:354
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
Definition: Merger.cpp:314
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
Definition: Merger.cpp:337
Level getLoopDependentLevel(TensorLoopId b) const
Definition: Merger.h:517
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on the ith loop.
Definition: Merger.h:415
constexpr bool isOutTensor(TensorLoopId b, LoopId i) const
Returns true if b is the ith loop of the output tensor.
Definition: Merger.h:357
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
Definition: Merger.cpp:576
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
Definition: Merger.cpp:678
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
Definition: Merger.cpp:521
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operands set to zero,...
Definition: Merger.cpp:413
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expression.
Definition: Merger.h:509
void dumpBits(const BitVector &bits) const
Definition: Merger.cpp:915
bool hasExprValue(ExprId e) const
Checks whether the given expression has an associated value.
Definition: Merger.h:554
void foreachTensorLoopId(LatPointId p, bool simple, ForeachTensorLoopIdCallback callback) const
Definition: Merger.h:447
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
Definition: Merger.cpp:308
std::optional< LoopId > getLoopId(TensorId t, Level lvl) const
Gets the loop identifier for the lvlth level of the tth tensor.
Definition: Merger.h:409
std::pair< TensorId, Level > getLoopDefiningLvl(LoopId i) const
Returns the defining [tid, lvl] for the loop.
Definition: Merger.h:493
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
Definition: Merger.cpp:460
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
Definition: Merger.cpp:543
constexpr unsigned getNumTensors() const
Gets the total number of tensors (including the output-tensor and synthetic-tensor).
Definition: Merger.h:351
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
Definition: Merger.h:500
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets that also sets one of the operands to zero: (s0 /\_op s1 (e0 op e1...
Definition: Merger.cpp:356
void dumpSet(LatSetId s) const
Definition: Merger.cpp:905
void dumpLat(LatPointId p) const
Definition: Merger.cpp:894
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
Definition: Merger.cpp:380
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
Definition: Merger.cpp:246
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:935
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
Definition: Merger.cpp:328
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
Definition: Merger.cpp:276
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
Definition: Merger.cpp:269
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition: Merger.h:248
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
Definition: Merger.cpp:1186
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
Definition: Merger.h:425
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
Definition: Merger.h:441
std::optional< Level > getLvl(TensorLoopId b) const
Definition: Merger.h:419
ArrayRef< LatPointId > set(LatSetId s) const
Definition: Merger.h:548
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:430
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
Definition: Merger.h:345
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
Definition: Merger.h:471
void dumpExp(ExprId e) const
Print methods (for debugging).
Definition: Merger.cpp:795
LevelType getLvlType(TensorLoopId b) const
Gets the level-type of the TensorLoopId.
Definition: Merger.h:404
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
Definition: Merger.cpp:223
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to a sparse level-type.
Definition: Merger.cpp:669
void clearExprValue(ExprId e)
Clears the value associated with the expression.
Definition: Merger.h:566
std::vector< LoopCoeffPair > & getDependentLoops(TensorId t, Level lvl)
Returns the list of loop indices which appear in the non-trivial index expression on t_l,...
Definition: Merger.h:487
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
Definition: Merger.cpp:292
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
Definition: Merger.cpp:254
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
Definition: Merger.h:366
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
Definition: Merger.cpp:502
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
Definition: Merger.h:540
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
Definition: Merger.h:347
const LatPoint & lat(LatPointId p) const
Definition: Merger.h:544
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
Definition: Merger.h:362
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
Definition: Merger.cpp:515
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
Definition: Merger.cpp:262
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
Definition: Merger.h:242
LevelType getLoopDependentLevelType(TensorLoopId b) const
Definition: Merger.h:522
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the t-th tensor on the i-th loop.
Definition: Merger.h:398
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:1537
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
Definition: Merger.h:254
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
Definition: Merger.h:369
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
Definition: Merger.h:558
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
Definition: Merger.cpp:400
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has a dependent slice.
Definition: Merger.h:480
@ Type
An inlay hint for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition: Merger.h:30
unsigned LatSetId
LatSet identifiers.
Definition: Merger.h:57
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
Definition: Merger.h:60
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition: Merger.h:44
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
unsigned LoopId
Loop identifiers.
Definition: Merger.h:38
bool isValidLT(LevelType lt)
Definition: Enums.h:429
unsigned ExprId
TensorExp identifiers.
Definition: Merger.h:48
unsigned LatPointId
LatPoint identifiers.
Definition: Merger.h:52
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
Definition: Merger.h:64
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition: Merger.h:35
Include the generated interface declarations.
LatPoint(const BitVector &bits, ExprId e)
Construct a lattice point from the given set of TensorLoopIds.
Definition: Merger.h:206
ExprId exp
Identifier of the tensor expression.
Definition: Merger.h:217
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
Definition: Merger.h:209
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
Definition: Merger.h:214
LatPoint(unsigned size, ExprId e)
Construct a lattice point with the empty set of TensorLoopIds.
Definition: Merger.h:203
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
Child subexpressions for non-leaf expressions.
Definition: Merger.h:71
Tensor expression. Represents an MLIR expression in tensor index notation.
Definition: Merger.h:67
LoopId loop
kLoopVar expressions simply have a loop identifier.
Definition: Merger.h:96
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition: Merger.h:105
Kind
Tensor expression kind.
Definition: Merger.h:129
Children children
All other expressions hold the ExprIds of their children.
Definition: Merger.h:99
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
Definition: Merger.h:118
TensorId tensor
kTensor expressions simply have a tensor identifier.
Definition: Merger.h:93
Kind kind
Tensor expression kind.
Definition: Merger.h:89
Operation * op
Code blocks used by semirings.
Definition: Merger.h:114
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Definition: Merger.cpp:105