MLIR  19.0.0git
Merger.h
Go to the documentation of this file.
1 //===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for dealing with iteration lattices.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
15 
19 #include "mlir/IR/Value.h"
20 #include "llvm/ADT/BitVector.h"
21 
22 #include <optional>
23 
24 namespace mlir {
25 namespace sparse_tensor {
26 
namespace detail {
/// A constant serving as the canonically invalid identifier, regardless of
/// the identifier type (every identifier alias in this header is `unsigned`).
///
/// Declared as a C++17 `inline constexpr` variable rather than
/// `static constexpr`: a namespace-scope `static` in a header gives every
/// including translation unit its own internal-linkage copy, whereas an
/// inline variable has exactly one definition program-wide.
inline constexpr unsigned kInvalidId = -1u;
} // namespace detail
32 
/// Identifier of a tensor operand. Its value is the
/// `BlockArgument::getArgNumber` of the value handed to
/// `Merger::buildTensorExp`.
using TensorId = unsigned int;

/// Identifier of a loop.
using LoopId = unsigned int;

/// Packed encoding of a `(TensorId, LoopId)` pair. The encoding doubles as
/// the bit position inside the bitvector held by `LatPoint`, because that
/// bitvector is simply the representation of a set of `TensorLoopId`s.
using TensorLoopId = unsigned int;

/// Identifier of a `TensorExp`. Handed out by `Merger::addExp`; each one
/// uniquely names the corresponding expression object.
using ExprId = unsigned int;

/// Identifier of a `LatPoint`. Handed out by `Merger::addLat`; each one
/// uniquely names the corresponding lattice point.
using LatPointId = unsigned int;

/// Identifier of a lattice set (a `SmallVector<LatPointId>`). Handed out by
/// `Merger::addSet`, either directly or through the methods that call it;
/// each one uniquely names the corresponding set.
using LatSetId = unsigned int;
58 
59 /// A pair of level and its corresponding LevelType of a tensor.
60 using LvlLTPair = std::pair<Level, LevelType>;
61 
62 /// A pair of loop id and its coefficients. E.g., for affine expression in the
63 /// affine map `2 * d0`, loop id = 0, coefficient = 2.
64 using LoopCoeffPair = std::pair<LoopId, unsigned>;
65 
66 /// Tensor expression. Represents an MLIR expression in tensor index notation.
67 struct TensorExp final {
68  enum class Kind;
69 
70  /// Child subexpressions for non-leaf expressions.
71  struct Children final {
74  };
75 
76  /// The `x` parameter has different types depending on the value of the
77  /// `k` parameter. The correspondences are:
78  /// * `kTensor` -> `TensorId`
79  /// * `kInvariant` -> `kInvalidId`
80  /// * `kLoopVar` -> `LoopId`
81  /// * else -> `ExprId`
82  ///
83  /// The `y`, `v`, and `op` parameters either must or must not be
84  /// `kInvalidId`/`nullptr`, depending on the value of the `k` parameter;
85  /// however, they have uniform C++ types regardless of the value of `k`.
86  TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a);
87 
88  /// Tensor expression kind.
90 
91  union {
92  /// `kTensor` expressions simply have a tensor identifier.
94 
95  /// `kLoopVar` expressions simply have a loop identifier.
97 
98  /// All other expressions hold the `ExprId`s of their children.
100  };
101 
102  /// Direct link to IR for an invariant or the destination value (to
103 /// infer destination type) of a cast operation. During code generation,
104  /// this field may be used to cache "hoisted" loop invariant tensor loads.
106 
107  /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
108  /// and kSelect, this holds the original operation with all regions. For
109  /// kBinaryBranch, this holds the YieldOp for the left or right half
110  /// to be merged into a nested scf loop.
111  ///
112 /// For kDenseOp, this instead holds the actual operation, which cannot be
113 /// sparsified but has all dense operands.
115 
116  /// An optional attribute that is required to determine the semantics of the
117  /// operations. E.g., CmpPredicateAttr for CmpI/CmpF operations.
119 };
120 
121 /// Tensor expression kind.
122 ///
123 /// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
124 /// That is, its argument is a `LoopId` identifying the loop-variable
125 /// in question, and its value will be the current iteration's value.
126 /// The `kSynZero` leaf kind is for representing a synthetic zero value,
127 /// which can be introduced when sparsifying operations like `arith::cmp`
128 /// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
129 enum class TensorExp::Kind {
130  // Leaf.
131  kTensor = 0,
132  kSynZero,
133  kInvariant,
134  kLoopVar,
135  // Unary operations.
136  kAbsF,
137  kAbsC,
138  kAbsI,
139  kCeilF,
140  kFloorF,
141  kSqrtF,
142  kSqrtC,
143  kExpm1F,
144  kExpm1C,
145  kLog1pF,
146  kLog1pC,
147  kRelu,
148  kSinF,
149  kSinC,
150  kTanhF,
151  kTanhC,
152  kNegF,
153  kNegC,
154  kNegI,
155  kTruncF,
156  kExtF,
157  kCastFS, // signed
158  kCastFU, // unsigned
159  kCastSF, // signed
160  kCastUF, // unsigned
161  kCastS, // signed
162  kCastU, // unsigned
163  kCastIdx,
164  kTruncI,
165  kCIm, // complex.im
166  kCRe, // complex.re
167  kBitCast,
168  kBinaryBranch, // semiring unary branch created from a binary op
169  kUnary, // semiring unary op
170  kSelect, // custom selection criteria
171  // Binary operations.
172  kMulF,
173  kMulC,
174  kMulI,
175  kDivF,
176  kDivC, // complex
177  kDivS, // signed
178  kDivU, // unsigned
179  kAddF,
180  kAddC,
181  kAddI,
182  kSubF,
183  kSubC,
184  kSubI,
185  kAndI,
186  kOrI,
187  kXorI,
188  kCmpI,
189  kCmpF,
190  kShrS, // signed
191  kShrU, // unsigned
192  kShlI,
193  kBinary, // semiring binary op
194  kReduce, // semiring reduction op
195  kDenseOp, // special category of operations requiring all dense operands
196 };
197 
198 /// Lattice point. Each lattice point consists of a formal conjunction
199 /// of `TensorLoopId`s, together with the identifier of the corresponding
200 /// tensor expression. The formal conjunction is represented as a set of
201 /// `TensorLoopId`, where that set is implemented as a `BitVector`.
202 struct LatPoint final {
203  /// Construct a lattice point with the empty set of `TensorLoopId`s.
204  LatPoint(unsigned size, ExprId e) : bits(size, false), exp(e) {}
205 
206  /// Construct a lattice point from the given set of `TensorLoopId`s.
207  LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}
208 
209  /// Conjunction of all `TensorLoopId`s involved in the tensor expression.
210  BitVector bits;
211 
212  /// Simplified conjunction of `TensorLoopId` as bitvector. This
213  /// represents a simplified condition under which this tensor expression
214  /// must execute. Pre-computed during codegen to avoid repeated eval.
215  BitVector simple;
216 
217  /// Identifier of the tensor expression.
219 };
220 
221 /// A class to handle all iteration lattice operations. This class abstracts
222 /// away from some implementation details of storing iteration lattices and
223 /// tensor expressions. This allows for fine-tuning performance characteristics
224 /// independently from the basic algorithm if bottlenecks are identified.
225 class Merger {
226 public:
227  /// Constructs a merger for the given number of tensors and loops. The user
228  /// supplies the number of tensors involved in the kernel, with the last
229  /// tensor in this set denoting the output tensor. The merger adds an
230  /// additional synthetic tensor at the end of this set to represent all
231  /// invariant expressions in the kernel.
232  ///
233  /// The maxLvlRank specifies the max level rank of all inputs/output tensors.
234  /// It is used to pre-allocate sufficient memory for internal storage.
235  Merger(unsigned numInputOutputTensors, unsigned numLoops,
236  unsigned maxLvlRank);
237 
238  //
239  // Constructing valid tensor and loop identifiers.
240  //
241 
242  /// Safely converts the argument to a tensor identifier.
243  constexpr TensorId makeTensorId(unsigned t) const {
244  assert(isValidTensorId(t));
245  return t;
246  }
247 
248  /// Safely converts the argument to a loop identifier.
249  constexpr LoopId makeLoopId(unsigned i) const {
250  assert(isValidLoopId(i));
251  return i;
252  }
253 
254  /// Safely converts the arguments to a pair of (tensor,loop) identifiers.
255  constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
256  assert(isValidTensorId(t) && isValidLoopId(i));
257  return numTensors * i + t;
258  }
259 
260  //
261  // Allocating new expressions, points, and sets.
262  //
263 
264  /// Constructs a new tensor expression, and returns its identifier.
266  /// Constructs a new loop-variable expression, and returns its identifier.
268  /// Constructs a new invariant expression, and returns its identifier.
270  /// Constructs a new synthetic zero expression.
272  /// Constructs a new unary or binary expression, and returns its identifier.
274  Operation *op = nullptr, Attribute attr = nullptr);
275  /// Constructs a new sesquinary expression, and returns its identifier.
276  /// Currently no sesquinary `Kind` allows specifying the `op`, but we
277  /// allow it anyways because `mapSet` is designed to allow it.
278  ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr,
279  Attribute attr = nullptr);
280 
281  /// Constructs a new iteration lattice point, and returns its identifier.
283  LatPointId addLat(const BitVector &bits, ExprId e);
284 
285  /// Constructs a new (initially empty) set, and returns its identifier.
286  LatSetId addSet();
287 
288  /// Computes a single conjunction of two lattice points by taking the "union"
289  /// of `LoopId` (effectively constructing a larger "intersection" of those
290  /// loops) with a newly constructed tensor (sub)expression of given kind.
291  /// Returns the identifier of the new lattice point.
293  Operation *op = nullptr);
294 
295  /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`.
296  /// Returns the identifier of the new set.
297  LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);
298 
299  /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`.
300  /// Returns the identifier of the new set.
301  LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);
302 
303 /// Disjunctive merge of two lattice sets that also sets one of the operands to
304 /// zero: `(s0 /\_op s1 (e0 op e1), s0 (0 op e0), s1 (e1 op 0))`.
305  /// Returns the identifier of the new set.
307 
308  /// Disjunctive merge of two lattice sets with custom handling of the
309  /// overlap, left, and right regions. Any region may be left missing
310  /// in the output. Returns the identifier of the new set.
312  bool includeLeft, TensorExp::Kind ltrans, Operation *opleft,
313  bool includeRight, TensorExp::Kind rtrans,
314  Operation *opright);
315 
316  /// Maps the unary operator over the lattice set of the operand, i.e. each
317  /// lattice point on an expression E is simply copied over, but with OP E
318  /// as new expression. Returns the identifier of the new set.
320  Operation *op = nullptr, Attribute attr = nullptr);
321 
322  /// Maps the binary operator to the same operation but with one of its operand
323  /// set to zero, i.e. each lattice point on an expression E is simply copied
324  /// over, but with `OP 0 E` (if lhsZero == true) or `OP E 0` (if lhsZero ==
325  /// false) as new expression. Returns the identifier of the new set.
326  LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero);
327 
328  /// Optimizes the iteration lattice points in the given set. This
329  /// method should be called right before code generation to avoid
330  /// generating redundant loops and conditions.
332 
333  /// Simplifies the conditions in a conjunction of a given lattice point
334  /// within the given set using just two basic rules:
335  /// (1) multiple dense conditions are reduced to single dense, and
336  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
337  BitVector simplifyCond(LatSetId s, LatPointId p);
338 
339  /// Returns true if p0 > p1.
340  bool latGT(LatPointId p0, LatPointId p1) const;
341 
342  /// Returns true if p0 and p1 only differ in dense.
343  bool onlyDenseDiff(LatPointId p0, LatPointId p1) const;
344 
345  /// Gets the tensor-identifier of the `TensorLoopId`.
346  constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; }
347  /// Gets the loop-identifier of the `TensorLoopId`.
348  constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; }
349 
350  /// Gets the total number of tensors (including the output-tensor and
351  /// synthetic-tensor).
352  constexpr unsigned getNumTensors() const { return numTensors; }
353 
354  /// Gets the total number of loops (native loops + filter loops).
355  constexpr unsigned getNumLoops() const { return numLoops; }
356 
357  /// Returns true if `b` is the `i`th loop of the output tensor.
358  constexpr bool isOutTensor(TensorLoopId b, LoopId i) const {
359  return b == makeTensorLoopId(outTensor, i);
360  }
361 
362  /// Gets the output tensor's identifier.
363  constexpr TensorId getOutTensorID() const { return outTensor; }
364 
365  /// Gets the synthetic tensor's identifier (used for all invariant
366  /// tensor expressions).
367  constexpr TensorId getSynTensorID() const { return syntheticTensor; }
368 
369  /// Returns true if the expression is `(kTensor t)`.
370  bool expIsTensor(ExprId e, TensorId t) const {
371  const auto &expr = exp(e);
372  return expr.kind == TensorExp::Kind::kTensor && expr.tensor == t;
373  }
374 
375  /// Returns true if the expression contains the tensor as an operand.
376  bool expContainsTensor(ExprId e, TensorId t) const;
377 
378  /// Returns true if the expression contains a negation on output tensor.
379  /// I.e., `- outTensor` or `exp - outputTensor`
380  /// NOTE: this is an trivial tests in that it does not handle recursive
381  /// negation, i.e., it returns true when the expression is `-(-tensor)`.
382  bool hasNegateOnOut(ExprId e) const;
383 
384  /// Returns true if given tensor iterates *only* in the given tensor
385  /// expression. For the output tensor, this defines a "simply dynamic"
386  /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
387  /// sparse vector a.
388  bool isSingleCondition(TensorId t, ExprId e) const;
389 
390  /// Returns true if any `TensorLoopId` in the bitvector corresponds
391  /// to sparse level-type.
392  bool hasAnySparse(const BitVector &bits) const;
393 
394  /// Returns true if bits contains a dependent index reduction condition on
395  /// sparse levels.
396  bool hasSparseIdxReduction(const BitVector &bits) const;
397 
398  /// Gets the level-type of the `t`th tensor on `i`th loop.
400  assert(isValidTensorId(t) && isValidLoopId(i));
401  return lvlTypes[t][i];
402  }
403 
404  /// Gets the level-type of the TensorLoopId.
406  return getLvlType(tensor(b), loop(b));
407  }
408 
409  /// Gets the loop identifier for the `lvl`th level of the `t`th tensor.
410  std::optional<LoopId> getLoopId(TensorId t, Level lvl) const {
411  assert(isValidLevel(t, lvl));
412  return lvlToLoop[t][lvl];
413  }
414 
415  /// Gets the level number of the the `t`th tensor on `i`th loop.
416  std::optional<Level> getLvl(TensorId t, LoopId i) const {
417  assert(isValidTensorId(t) && isValidLoopId(i));
418  return loopToLvl[t][i];
419  }
420  std::optional<Level> getLvl(TensorLoopId b) const {
421  return getLvl(tensor(b), loop(b));
422  }
423 
424  /// Sets the level number and level-type of the `t`th tensor on
425  /// `i`th loop.
427  assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt));
428  lvlTypes[t][i] = lt;
429  loopToLvl[t][i] = lvl;
430  lvlToLoop[t][lvl] = i;
431  // TODO: favor a constant loop bound when there are multiple choices.
432  loopBounds[i] = std::make_pair(t, lvl);
433  }
434 
436  TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>;
437 
438  /// Iterates over a set of `TensorLoopId`s, invoking the callback
439  /// for each `TensorLoopId` and passing it the corresponding tensor
440 /// identifier, level, and level-type, followed by a boolean value
441  /// indicating whether it is a dependent index reduction loop condition.
443  ForeachTensorLoopIdCallback callback) const {
444  // TODO: the default ought to be simple=true; but we'll need to make
445  // sure to update all the tests to make sure they do the right thing.
446  foreachTensorLoopId(p, /*simple=*/false, callback);
447  }
448  void foreachTensorLoopId(LatPointId p, bool simple,
449  ForeachTensorLoopIdCallback callback) const {
450  const auto &point = lat(p);
451  const auto &bits = simple ? point.simple : point.bits;
452  for (const TensorLoopId b : bits.set_bits()) {
453  const TensorId t = tensor(b);
454  const auto optLvl = getLvl(b);
455  const auto lvlTp = getLvlType(b);
456  if (isLvlWithNonTrivialIdxExp(b)) {
457  // This must be an undefined level.
458  assert(!optLvl.has_value());
459  // Slice the tid along the dependent level to iterate current loop.
460  callback(b, t, getLoopDependentLevel(b), lvlTp,
461  /*isIdxReduc=*/true);
462  } else {
463  callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false);
464  }
465  }
466  }
467 
468  /// Sets whether the output tensor is sparse or not.
469  void setHasSparseOut(bool s) { hasSparseOut = s; }
470 
471  /// Establishes the two-way map that i <-> <t, lvl, lt>.
473  LevelType lt, unsigned coefficient) {
474  assert(isValidLoopId(i) && isValidLevel(t, lvl));
475  assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
476  loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);
477  levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
478  }
479 
480 /// Returns true if loop `i` has a dependent slice on tensor `t`.
482  assert(isValidTensorId(t) && isValidLoopId(i));
483  return loopToUnresolvedLvls[i][t].has_value();
484  }
485 
486  /// Returns the list of loop indices which appear in the non-trivial index
487  /// expression on t_l, e.g., A[i+j] => {i, j}
488  std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
489  assert(isValidLevel(t, lvl));
490  return levelToDependentLoop[t][lvl];
491  }
492 
493  /// Returns the defining [tid, lvl] for the loop.
494  std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const {
495  assert(isValidLoopId(i));
496  return loopBounds[i];
497  }
498 
499 /// Checks whether the TensorLoopId represents a tensor level that contains
500 /// a non-trivial index expression.
502  const TensorId t = tensor(b);
503  const LoopId i = loop(b);
504  assert(isValidTensorId(t) && isValidLoopId(i));
505  return loopToUnresolvedLvls[i][t].has_value();
506  }
507 
508 /// Checks whether the TensorLoopId represents a sparse tensor level that
509 /// contains a non-trivial index expression.
511  if (isLvlWithNonTrivialIdxExp(b)) {
512  auto lt = getLoopDependentLevelType(b);
513  return lt.hasSparseSemantic();
514  }
515  return false;
516  }
517 
519  assert(isLvlWithNonTrivialIdxExp(b));
520  return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
521  }
522 
524  assert(isLvlWithNonTrivialIdxExp(b));
525  return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
526  }
527 
528  /// Convenience getters to immediately access the stored nodes.
529  /// These methods return `const&` because the underlying objects must
530  /// not be mutated by client code. The only exception is for mutating
531  /// the value associated with an expression, for which there are
532  /// dedicated methods below.
533  ///
534  /// NOTE: It is inadvisable to keep the reference alive for a long
535  /// time (e.g., as in `TensorExpr &te = merger.exp(e)`), since insertions
536  /// into the merger can cause data movement which will invalidate the
537  /// underlying memory address. This isn't just a problem with the `&`
538  /// references, but also applies to the `ArrayRef`. In particular,
539  /// using `for (LatPointId p : merger.set(s))` will run into the same
540  /// dangling-reference problems if the loop body inserts new sets.
541  const TensorExp &exp(ExprId e) const {
542  assert(isValidExprId(e));
543  return tensorExps[e];
544  }
545  const LatPoint &lat(LatPointId p) const {
546  assert(isValidLatPointId(p));
547  return latPoints[p];
548  }
550  assert(isValidLatSetId(s));
551  return latSets[s];
552  }
553 
554  /// Checks whether the given expression has an associated value.
555  bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
556 
557  /// Sets the expression to have the associated value. Asserts that the new
558  /// value is defined, and that the expression does not already have a value.
559  void setExprValue(ExprId e, Value v) {
560  assert(!exp(e).val && "Expression already has an associated value");
561  assert(v && "Trying to assign an undefined value");
562  tensorExps[e].val = v;
563  }
564 
565  /// Clears the value associated with the expression. Asserts that the
566  /// expression does indeed have an associated value before clearing it.
568  assert(exp(e).val && "Expression does not have an associated value");
569  tensorExps[e].val = Value();
570  }
571 
572 #ifndef NDEBUG
573  /// Print methods (for debugging).
574  void dumpExp(ExprId e) const;
575  void dumpLat(LatPointId p) const;
576  void dumpSet(LatSetId s) const;
577  void dumpBits(const BitVector &bits) const;
578 #endif
579 
580  /// Builds the iteration lattices in a bottom-up traversal given the
581  /// remaining tensor (sub)expression and the next loop in the iteration
582  /// graph. Returns the identifier of the root set.
584 
585  /// Builds a tensor expression from the given Linalg operation.
586  /// On success, returns the identifier of the root expression.
587  std::optional<ExprId> buildTensorExpFromLinalg(linalg::GenericOp op);
588 
589  /// Rebuilds SSA format from a tensor expression.
590  Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
591  Value v1) const;
592 
593 private:
594  /// Private helpers.
595  constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; }
596  constexpr bool isValidLoopId(LoopId i) const {
597  return i != detail::kInvalidId && i < numLoops;
598  }
599  bool isValidLevel(TensorId t, Level lvl) const {
600  assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());
601  return isValidTensorId(t) && lvl < lvlToLoop[t].size();
602  }
603  bool isValidExprId(ExprId e) const {
604  return e != detail::kInvalidId && e < tensorExps.size();
605  }
606  bool isValidLatPointId(LatPointId p) const {
607  return p != detail::kInvalidId && p < latPoints.size();
608  }
609  bool isValidLatSetId(LatSetId s) const {
610  return s != detail::kInvalidId && s < latSets.size();
611  }
612  bool maybeZero(ExprId e) const;
613  bool isInvariant(ExprId e) const {
614  return exp(e).kind == TensorExp::Kind::kInvariant;
615  }
616  Type inferType(ExprId e, Value src) const;
617 
618  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
619  /// The boolean value returned indicates whether the result of the current
620  /// operation being built depends on any value that is loaded from a sparse
621  /// tensor.
622  std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
623  Value v);
624 
625  /// Merger data structures.
626  const TensorId outTensor;
627  const TensorId syntheticTensor;
628  const unsigned numTensors;
629  const unsigned numLoops;
630  bool hasSparseOut;
631 
632  // Below we use `std::vector` for things which have a priori fixed
633  // sizes, whereas we use `llvm::SmallVector` for things with variable
634  // size. Do beware that these two classes differ in the semantics of
635  // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
636  // does not.
637 
638  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
639  std::vector<std::vector<LevelType>> lvlTypes;
640 
641  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
642  std::vector<std::vector<std::optional<Level>>> loopToLvl;
643 
644  /// Map that converts pair<TensorId, Level> to the corresponding LoopId.
645  std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
646 
647  /// Map from a loop to its dependencies if any.
648  /// The dependencies of a loop is a set of (tensor, level) pairs.
649  /// It is currently only set for non-trivial index expressions.
650  /// E.g., A[i+j] => i and j will have dependencies {A0, lt(A0)} to indicate
651  /// that i and j are used in the non-trivial index expression on A0.
652  std::vector<std::vector<std::optional<LvlLTPair>>> loopToUnresolvedLvls;
653 
654  /// The inverse map of ldxToDependencies from tensor level -> dependent loop
655  /// E.g., A[2i+j], we have A0 => {(2, i), (1, j)}, to indicate that A0 uses
656  /// both {i, j} to compute its indices and the coefficients on the loop id are
657  /// 2 and 1 respectively.
658  std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
659 
660  /// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
661  std::vector<std::pair<TensorId, Level>> loopBounds;
662 
663  llvm::SmallVector<TensorExp> tensorExps;
664  llvm::SmallVector<LatPoint> latPoints;
666 };
667 
668 } // namespace sparse_tensor
669 } // namespace mlir
670 
671 #endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A class to handle all iteration lattice operations.
Definition: Merger.h:225
void setHasSparseOut(bool s)
Sets whether the output tensor is sparse or not.
Definition: Merger.h:469
constexpr unsigned getNumLoops() const
Gets the total number of loops (native loops + filter loops).
Definition: Merger.h:355
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
Definition: Merger.cpp:315
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
Definition: Merger.cpp:338
Level getLoopDependentLevel(TensorLoopId b) const
Definition: Merger.h:518
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on ith loop.
Definition: Merger.h:416
constexpr bool isOutTensor(TensorLoopId b, LoopId i) const
Returns true if b is the ith loop of the output tensor.
Definition: Merger.h:358
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
Definition: Merger.cpp:577
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
Definition: Merger.cpp:680
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
Definition: Merger.cpp:522
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operands set to zero,...
Definition: Merger.cpp:414
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expressio...
Definition: Merger.h:510
void dumpBits(const BitVector &bits) const
Definition: Merger.cpp:920
bool hasExprValue(ExprId e) const
Checks whether the given expression has an associated value.
Definition: Merger.h:555
void foreachTensorLoopId(LatPointId p, bool simple, ForeachTensorLoopIdCallback callback) const
Definition: Merger.h:448
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
Definition: Merger.cpp:309
std::optional< LoopId > getLoopId(TensorId t, Level lvl) const
Gets the loop identifier for the lvlth level of the tth tensor.
Definition: Merger.h:410
std::pair< TensorId, Level > getLoopDefiningLvl(LoopId i) const
Returns the defining [tid, lvl] for the loop.
Definition: Merger.h:494
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
Definition: Merger.cpp:461
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
Definition: Merger.cpp:544
constexpr unsigned getNumTensors() const
Gets the total number of tensors (including the output-tensor and synthetic-tensor).
Definition: Merger.h:352
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
Definition: Merger.h:501
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets that also sets one of the operands to zero: (s0 /\_op s1 (e0 op e1...
Definition: Merger.cpp:356
void dumpSet(LatSetId s) const
Definition: Merger.cpp:910
void dumpLat(LatPointId p) const
Definition: Merger.cpp:899
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
Definition: Merger.cpp:380
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
Definition: Merger.cpp:247
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:940
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
Definition: Merger.cpp:329
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
Definition: Merger.cpp:277
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
Definition: Merger.cpp:270
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition: Merger.h:249
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
Definition: Merger.cpp:1193
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
Definition: Merger.h:426
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
Definition: Merger.cpp:401
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
Definition: Merger.h:442
std::optional< Level > getLvl(TensorLoopId b) const
Definition: Merger.h:420
ArrayRef< LatPointId > set(LatSetId s) const
Definition: Merger.h:549
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:431
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
Definition: Merger.h:346
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
Definition: Merger.h:472
void dumpExp(ExprId e) const
Print methods (for debugging).
Definition: Merger.cpp:799
LevelType getLvlType(TensorLoopId b) const
Gets the level-type of the TensorLoopId.
Definition: Merger.h:405
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
Definition: Merger.cpp:224
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to a sparse level-type.
Definition: Merger.cpp:671
void clearExprValue(ExprId e)
Clears the value associated with the expression.
Definition: Merger.h:567
std::vector< LoopCoeffPair > & getDependentLoops(TensorId t, Level lvl)
Returns the list of loop indices which appear in the non-trivial index expression on t_l,...
Definition: Merger.h:488
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
Definition: Merger.cpp:293
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
Definition: Merger.cpp:255
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
Definition: Merger.h:367
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
Definition: Merger.cpp:503
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
Definition: Merger.h:541
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
Definition: Merger.h:348
const LatPoint & lat(LatPointId p) const
Definition: Merger.h:545
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
Definition: Merger.h:363
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 differ only in dense levels.
Definition: Merger.cpp:516
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
Definition: Merger.cpp:263
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
Definition: Merger.h:243
LevelType getLoopDependentLevelType(TensorLoopId b) const
Definition: Merger.h:523
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the t-th tensor on the i-th loop.
Definition: Merger.h:399
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:1618
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
Definition: Merger.h:255
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
Definition: Merger.h:370
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
Definition: Merger.h:559
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has a dependent slice.
Definition: Merger.h:481
@ Type
An inlay hint for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition: Merger.h:30
unsigned LatSetId
LatSet identifiers.
Definition: Merger.h:57
std::pair< Level, LevelType > LvlLTPair
A pair of a tensor level and its corresponding LevelType.
Definition: Merger.h:60
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition: Merger.h:44
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
unsigned LoopId
Loop identifiers.
Definition: Merger.h:38
bool isValidLT(LevelType lt)
Definition: Enums.h:433
unsigned ExprId
TensorExp identifiers.
Definition: Merger.h:48
unsigned LatPointId
LatPoint identifiers.
Definition: Merger.h:52
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of a loop id and its coefficient.
Definition: Merger.h:64
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition: Merger.h:35
Include the generated interface declarations.
LatPoint(const BitVector &bits, ExprId e)
Construct a lattice point from the given set of TensorLoopIds.
Definition: Merger.h:207
ExprId exp
Identifier of the tensor expression.
Definition: Merger.h:218
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
Definition: Merger.h:210
BitVector simple
Simplified conjunction of TensorLoopIds as a bitvector.
Definition: Merger.h:215
LatPoint(unsigned size, ExprId e)
Construct a lattice point with the empty set of TensorLoopIds.
Definition: Merger.h:204
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
Child subexpressions for non-leaf expressions.
Definition: Merger.h:71
Tensor expression. Represents an MLIR expression in tensor index notation.
Definition: Merger.h:67
LoopId loop
kLoopVar expressions simply have a loop identifier.
Definition: Merger.h:96
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition: Merger.h:105
Kind
Tensor expression kind.
Definition: Merger.h:129
Children children
All other expressions hold the ExprIds of their children.
Definition: Merger.h:99
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
Definition: Merger.h:118
TensorId tensor
kTensor expressions simply have a tensor identifier.
Definition: Merger.h:93
Kind kind
Tensor expression kind.
Definition: Merger.h:89
Operation * op
Code blocks used by semirings.
Definition: Merger.h:114
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Definition: Merger.cpp:106