MLIR 22.0.0git
Merger.h
Go to the documentation of this file.
1//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines utilities for dealing with iteration lattices.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
15
19#include "mlir/IR/Value.h"
20#include "llvm/ADT/BitVector.h"
21
22#include <optional>
23
24namespace mlir {
25namespace sparse_tensor {
26
namespace detail {
/// The canonical invalid identifier, shared by every identifier type.
/// Its value is `-1u`, i.e. the largest value representable in `unsigned`.
/// Declared `inline` so that all translation units including this header
/// refer to one and the same variable instead of per-TU `static` copies.
inline constexpr unsigned kInvalidId = -1u;
} // namespace detail
32
33/// Tensor identifiers, chosen to be the `BlockArgument::getArgNumber`
34/// of the value passed to `Merger::buildTensorExp`.
36
37/// Loop identifiers.
39
40/// A compressed representation of `std::pair<TensorId, LoopId>`.
41/// The compression scheme is such that this also serves as an index
42/// into the bitvector stored in `LatPoint` (since that bitvector is
43/// just the implementation for a set of `TensorLoopId` values).
45
46/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
47/// and serve as unique identifiers for the corresponding `TensorExp` object.
49
50/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
51/// and serve as unique identifiers for the corresponding `LatPoint` object.
53
54/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and
55/// by other methods calling that one), and serve as unique identifiers
56/// for the corresponding `SmallVector<LatPointId>` object.
58
59/// A pair of level and its corresponding LevelType of a tensor.
60using LvlLTPair = std::pair<Level, LevelType>;
61
62/// A pair of loop id and its coefficients. E.g., for affine expression in the
63/// affine map `2 * d0`, loop id = 0, coefficient = 2.
64using LoopCoeffPair = std::pair<LoopId, unsigned>;
65
66/// Tensor expression. Represents an MLIR expression in tensor index notation.
67struct TensorExp final {
68 enum class Kind;
69
70 /// Child subexpressions for non-leaf expressions.
71 struct Children final {
74 };
75
76 /// The `x` parameter has different types depending on the value of the
77 /// `k` parameter. The correspondences are:
78 /// * `kTensor` -> `TensorId`
79 /// * `kInvariant` -> `kInvalidId`
80 /// * `kLoopVar` -> `LoopId`
81 /// * else -> `ExprId`
82 ///
83 /// The `y`, `v`, and `op` parameters either must or must not be
84 /// `kInvalidId`/`nullptr`, depending on the value of the `k` parameter;
85 /// however, they have uniform C++ types regardless of the value of `k`.
86 TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a);
87
88 /// Tensor expression kind.
90
91 union {
92 /// `kTensor` expressions simply have a tensor identifier.
94
95 /// `kLoopVar` expressions simply have a loop identifier.
97
98 /// All other expressions hold the `ExprId`s of their children.
100 };
101
102 /// Direct link to IR for an invariant or the destination value (to
103 /// infer destination type) of a cast operation. During code generation,
104 /// this field may be used to cache "hoisted" loop invariant tensor loads.
106
107 /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
108 /// and kSelect, this holds the original operation with all regions. For
109 /// kBinaryBranch, this holds the YieldOp for the left or right half
110 /// to be merged into a nested scf loop.
111 ///
112 /// For kDenseOp, this holds the actual operation that we cannot sparsify
113 /// but that has all dense operands.
115
116 /// An optional attribute that is required to determine the semantics of the
117 /// operations. E.g., CmpPredicateAttr for CmpI/CmpF operations.
119};
120
121/// Tensor expression kind.
122///
123/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
124/// That is, its argument is a `LoopId` identifying the loop-variable
125/// in question, and its value will be the current iteration's value.
126/// The `kSynZero` leaf kind is for representing a synthetic zero value,
127/// which can be introduced when sparsifying operations like `arith::cmp`
128/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
129enum class TensorExp::Kind {
130 // Leaf.
135 // Unary operations.
157 kCastFS, // signed
158 kCastFU, // unsigned
159 kCastSF, // signed
160 kCastUF, // unsigned
161 kCastS, // signed
162 kCastU, // unsigned
165 kCIm, // complex.im
166 kCRe, // complex.re
168 kBinaryBranch, // semiring unary branch created from a binary op
169 kUnary, // semiring unary op
170 kSelect, // custom selection criteria
171 // Binary operations.
176 kDivC, // complex
177 kDivS, // signed
178 kDivU, // unsigned
190 kShrS, // signed
191 kShrU, // unsigned
193 kBinary, // semiring binary op
194 kReduce, // semiring reduction op
195 kDenseOp, // special category of operations requiring all dense operands
196};
197
198/// Lattice point. Each lattice point consists of a formal conjunction
199/// of `TensorLoopId`s, together with the identifier of the corresponding
200/// tensor expression. The formal conjunction is represented as a set of
201/// `TensorLoopId`, where that set is implemented as a `BitVector`.
202struct LatPoint final {
203 /// Construct a lattice point with the empty set of `TensorLoopId`s.
204 LatPoint(unsigned size, ExprId e) : bits(size, false), exp(e) {}
205
206 /// Construct a lattice point from the given set of `TensorLoopId`s.
207 LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {}
208
209 /// Conjunction of all `TensorLoopId`s involved in the tensor expression.
210 BitVector bits;
211
212 /// Simplified conjunction of `TensorLoopId` as bitvector. This
213 /// represents a simplified condition under which this tensor expression
214 /// must execute. Pre-computed during codegen to avoid repeated eval.
215 BitVector simple;
216
217 /// Identifier of the tensor expression.
219};
220
221/// A class to handle all iteration lattice operations. This class abstracts
222/// away from some implementation details of storing iteration lattices and
223/// tensor expressions. This allows for fine-tuning performance characteristics
224/// independently from the basic algorithm if bottlenecks are identified.
225class Merger {
226public:
227 /// Constructs a merger for the given number of tensors and loops. The user
228 /// supplies the number of tensors involved in the kernel, with the last
229 /// tensor in this set denoting the output tensor. The merger adds an
230 /// additional synthetic tensor at the end of this set to represent all
231 /// invariant expressions in the kernel.
232 ///
233 /// The maxLvlRank specifies the max level rank of all inputs/output tensors.
234 /// It is used to pre-allocate sufficient memory for internal storage.
235 Merger(unsigned numInputOutputTensors, unsigned numLoops,
236 unsigned maxLvlRank);
237
238 //
239 // Constructing valid tensor and loop identifiers.
240 //
241
242 /// Safely converts the argument to a tensor identifier.
243 constexpr TensorId makeTensorId(unsigned t) const {
244 assert(isValidTensorId(t));
245 return t;
246 }
247
248 /// Safely converts the argument to a loop identifier.
249 constexpr LoopId makeLoopId(unsigned i) const {
250 assert(isValidLoopId(i));
251 return i;
252 }
253
254 /// Safely converts the arguments to a pair of (tensor,loop) identifiers.
255 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
256 assert(isValidTensorId(t) && isValidLoopId(i));
257 return numTensors * i + t;
258 }
259
260 //
261 // Allocating new expressions, points, and sets.
262 //
263
264 /// Constructs a new tensor expression, and returns its identifier.
266 /// Constructs a new loop-variable expression, and returns its identifier.
268 /// Constructs a new invariant expression, and returns its identifier.
270 /// Constructs a new synthetic zero expression.
272 /// Constructs a new unary or binary expression, and returns its identifier.
274 Operation *op = nullptr, Attribute attr = nullptr);
275 /// Constructs a new sesquinary expression, and returns its identifier.
276 /// Currently no sesquinary `Kind` allows specifying the `op`, but we
277 /// allow it anyways because `mapSet` is designed to allow it.
278 ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr,
279 Attribute attr = nullptr);
280
281 /// Constructs a new iteration lattice point, and returns its identifier.
283 LatPointId addLat(const BitVector &bits, ExprId e);
284
285 /// Constructs a new (initially empty) set, and returns its identifier.
287
288 /// Computes a single conjunction of two lattice points by taking the "union"
289 /// of `LoopId` (effectively constructing a larger "intersection" of those
290 /// loops) with a newly constructed tensor (sub)expression of given kind.
291 /// Returns the identifier of the new lattice point.
293 Operation *op = nullptr);
294
295 /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`.
296 /// Returns the identifier of the new set.
297 LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);
298
299 /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`.
300 /// Returns the identifier of the new set.
301 LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op = nullptr);
302
303 /// Disjunctive merge of two lattice sets that also sets one of the operands to
304 /// zero: `(s0 /\_op s1 (e0 op e1), s0 (0 op e0), s1 (e1 op 0))`.
305 /// Returns the identifier of the new set.
307
308 /// Disjunctive merge of two lattice sets with custom handling of the
309 /// overlap, left, and right regions. Any region may be left missing
310 /// in the output. Returns the identifier of the new set.
312 bool includeLeft, TensorExp::Kind ltrans, Operation *opleft,
313 bool includeRight, TensorExp::Kind rtrans,
314 Operation *opright);
315
316 /// Maps the unary operator over the lattice set of the operand, i.e. each
317 /// lattice point on an expression E is simply copied over, but with OP E
318 /// as new expression. Returns the identifier of the new set.
320 Operation *op = nullptr, Attribute attr = nullptr);
321
322 /// Maps the binary operator to the same operation but with one of its operands
323 /// set to zero, i.e. each lattice point on an expression E is simply copied
324 /// over, but with `OP 0 E` (if lhsZero == true) or `OP E 0` (if lhsZero ==
325 /// false) as new expression. Returns the identifier of the new set.
326 LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero);
327
328 /// Optimizes the iteration lattice points in the given set. This
329 /// method should be called right before code generation to avoid
330 /// generating redundant loops and conditions.
332
333 /// Simplifies the conditions in a conjunction of a given lattice point
334 /// within the given set using just two basic rules:
335 /// (1) multiple dense conditions are reduced to single dense, and
336 /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
337 BitVector simplifyCond(LatSetId s, LatPointId p);
338
339 /// Returns true if p0 > p1.
340 bool latGT(LatPointId p0, LatPointId p1) const;
341
342 /// Returns true if p0 and p1 only differ in dense.
343 bool onlyDenseDiff(LatPointId p0, LatPointId p1) const;
344
345 /// Gets the tensor-identifier of the `TensorLoopId`.
346 constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; }
347 /// Gets the loop-identifier of the `TensorLoopId`.
348 constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; }
349
350 /// Gets the total number of tensors (including the output-tensor and
351 /// synthetic-tensor).
352 constexpr unsigned getNumTensors() const { return numTensors; }
353
354 /// Gets the total number of loops (native loops + filter loops).
355 constexpr unsigned getNumLoops() const { return numLoops; }
356
357 /// Returns true if `b` is the `i`th loop of the output tensor.
358 constexpr bool isOutTensor(TensorLoopId b, LoopId i) const {
359 return b == makeTensorLoopId(outTensor, i);
360 }
361
362 /// Gets the output tensor's identifier.
363 constexpr TensorId getOutTensorID() const { return outTensor; }
364
365 /// Gets the synthetic tensor's identifier (used for all invariant
366 /// tensor expressions).
367 constexpr TensorId getSynTensorID() const { return syntheticTensor; }
368
369 /// Returns true if the expression is `(kTensor t)`.
370 bool expIsTensor(ExprId e, TensorId t) const {
371 const auto &expr = exp(e);
372 return expr.kind == TensorExp::Kind::kTensor && expr.tensor == t;
373 }
374
375 /// Returns true if the expression contains the tensor as an operand.
376 bool expContainsTensor(ExprId e, TensorId t) const;
377
378 /// Returns true if the expression contains a negation on output tensor.
379 /// I.e., `- outTensor` or `exp - outputTensor`
380 /// NOTE: this is a trivial test in that it does not handle recursive
381 /// negation, i.e., it returns true when the expression is `-(-tensor)`.
382 bool hasNegateOnOut(ExprId e) const;
383
384 /// Returns true if given tensor iterates *only* in the given tensor
385 /// expression. For the output tensor, this defines a "simply dynamic"
386 /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
387 /// sparse vector a.
388 bool isSingleCondition(TensorId t, ExprId e) const;
389
390 /// Returns true if any `TensorLoopId` in the bitvector corresponds
391 /// to sparse level-type.
392 bool hasAnySparse(const BitVector &bits) const;
393
394 /// Returns true if bits contains a dependent index reduction condition on
395 /// sparse levels.
396 bool hasSparseIdxReduction(const BitVector &bits) const;
397
398 /// Gets the level-type of the `t`th tensor on `i`th loop.
400 assert(isValidTensorId(t) && isValidLoopId(i));
401 return lvlTypes[t][i];
402 }
403
404 /// Gets the level-type of the TensorLoopId.
406 return getLvlType(tensor(b), loop(b));
407 }
408
409 /// Gets the loop identifier for the `lvl`th level of the `t`th tensor.
410 std::optional<LoopId> getLoopId(TensorId t, Level lvl) const {
411 assert(isValidLevel(t, lvl));
412 return lvlToLoop[t][lvl];
413 }
414
415 /// Gets the level number of the the `t`th tensor on `i`th loop.
416 std::optional<Level> getLvl(TensorId t, LoopId i) const {
417 assert(isValidTensorId(t) && isValidLoopId(i));
418 return loopToLvl[t][i];
419 }
420 std::optional<Level> getLvl(TensorLoopId b) const {
421 return getLvl(tensor(b), loop(b));
422 }
423
424 /// Sets the level number and level-type of the `t`th tensor on
425 /// `i`th loop.
427 assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidLT(lt));
428 lvlTypes[t][i] = lt;
429 loopToLvl[t][i] = lvl;
430 lvlToLoop[t][lvl] = i;
431 // TODO: favor a constant loop bound when there are multiple choices.
432 loopBounds[i] = std::make_pair(t, lvl);
433 }
434
436 TensorLoopId, TensorId, std::optional<Level>, LevelType, bool)>;
437
438 /// Iterates over a set of `TensorLoopId`s, invoking the callback
439 /// for each `TensorLoopId` and passing it the corresponding tensor
440 /// identifier, level, and level-type, following with a boolean value
441 /// indicating whether it is a dependent index reduction loop condition.
443 ForeachTensorLoopIdCallback callback) const {
444 // TODO: the default ought to be simple=true; but we'll need to make
445 // sure to update all the tests to make sure they do the right thing.
446 foreachTensorLoopId(p, /*simple=*/false, callback);
447 }
448 void foreachTensorLoopId(LatPointId p, bool simple,
449 ForeachTensorLoopIdCallback callback) const {
450 const auto &point = lat(p);
451 const auto &bits = simple ? point.simple : point.bits;
452 for (const TensorLoopId b : bits.set_bits()) {
453 const TensorId t = tensor(b);
454 const auto optLvl = getLvl(b);
455 const auto lvlTp = getLvlType(b);
457 // This must be an undefined level.
458 assert(!optLvl.has_value());
459 // Slice the tid along the dependent level to iterate current loop.
460 callback(b, t, getLoopDependentLevel(b), lvlTp,
461 /*isIdxReduc=*/true);
462 } else {
463 callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false);
464 }
465 }
466 }
467
468 /// Sets whether the output tensor is sparse or not.
469 void setHasSparseOut(bool s) { hasSparseOut = s; }
470
471 /// Establishes the two-way map that i <-> <t, lvl, lt>.
473 LevelType lt, unsigned coefficient) {
474 assert(isValidLoopId(i) && isValidLevel(t, lvl));
475 assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
476 loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);
477 levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
478 }
479
480 /// Whether the loop has dependent slice.
482 assert(isValidTensorId(t) && isValidLoopId(i));
483 return loopToUnresolvedLvls[i][t].has_value();
484 }
485
486 /// Returns the list of loop indices which appear in the non-trivial index
487 /// expression on t_l, e.g., A[i+j] => {i, j}
488 std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
489 assert(isValidLevel(t, lvl));
490 return levelToDependentLoop[t][lvl];
491 }
492
493 /// Returns the defining [tid, lvl] for the loop.
494 std::pair<TensorId, Level> getLoopDefiningLvl(LoopId i) const {
495 assert(isValidLoopId(i));
496 return loopBounds[i];
497 }
498
499 /// Checks whether the TensorLoopId represents a tensor level that contains
500 /// a non-trivial index expression.
502 const TensorId t = tensor(b);
503 const LoopId i = loop(b);
504 assert(isValidTensorId(t) && isValidLoopId(i));
505 return loopToUnresolvedLvls[i][t].has_value();
506 }
507
508 /// Checks whether the TensorLoopId represents a sparse tensor level that
509 /// contains a non-trivial index expression.
512 auto lt = getLoopDependentLevelType(b);
513 return lt.hasSparseSemantic();
514 }
515 return false;
516 }
517
520 return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
521 }
522
525 return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
526 }
527
528 /// Convenience getters to immediately access the stored nodes.
529 /// These methods return `const&` because the underlying objects must
530 /// not be mutated by client code. The only exception is for mutating
531 /// the value associated with an expression, for which there are
532 /// dedicated methods below.
533 ///
534 /// NOTE: It is inadvisable to keep the reference alive for a long
535 /// time (e.g., as in `const TensorExp &te = merger.exp(e)`), since insertions
536 /// into the merger can cause data movement which will invalidate the
537 /// underlying memory address. This isn't just a problem with the `&`
538 /// references, but also applies to the `ArrayRef`. In particular,
539 /// using `for (LatPointId p : merger.set(s))` will run into the same
540 /// dangling-reference problems if the loop body inserts new sets.
541 const TensorExp &exp(ExprId e) const {
542 assert(isValidExprId(e));
543 return tensorExps[e];
544 }
545 const LatPoint &lat(LatPointId p) const {
546 assert(isValidLatPointId(p));
547 return latPoints[p];
548 }
550 assert(isValidLatSetId(s));
551 return latSets[s];
552 }
553
554 /// Checks whether the given expression has an associated value.
555 bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
556
557 /// Sets the expression to have the associated value. Asserts that the new
558 /// value is defined, and that the expression does not already have a value.
560 assert(!exp(e).val && "Expression already has an associated value");
561 assert(v && "Trying to assign an undefined value");
562 tensorExps[e].val = v;
563 }
564
565 /// Clears the value associated with the expression. Asserts that the
566 /// expression does indeed have an associated value before clearing it.
568 assert(exp(e).val && "Expression does not have an associated value");
569 tensorExps[e].val = Value();
570 }
571
572#ifndef NDEBUG
573 /// Print methods (for debugging).
574 void dumpExp(ExprId e) const;
575 void dumpLat(LatPointId p) const;
576 void dumpSet(LatSetId s) const;
577 void dumpBits(const BitVector &bits) const;
578#endif
579
580 /// Builds the iteration lattices in a bottom-up traversal given the
581 /// remaining tensor (sub)expression and the next loop in the iteration
582 /// graph. Returns the identifier of the root set.
584
585 /// Builds a tensor expression from the given Linalg operation.
586 /// On success, returns the identifier of the root expression.
587 std::optional<ExprId> buildTensorExpFromLinalg(linalg::GenericOp op);
588
589 /// Rebuilds SSA format from a tensor expression.
590 Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
591 Value v1) const;
592
593private:
594 /// Private helpers.
595 constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; }
596 constexpr bool isValidLoopId(LoopId i) const {
597 return i != detail::kInvalidId && i < numLoops;
598 }
599 bool isValidLevel(TensorId t, Level lvl) const {
600 assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());
601 return isValidTensorId(t) && lvl < lvlToLoop[t].size();
602 }
603 bool isValidExprId(ExprId e) const {
604 return e != detail::kInvalidId && e < tensorExps.size();
605 }
606 bool isValidLatPointId(LatPointId p) const {
607 return p != detail::kInvalidId && p < latPoints.size();
608 }
609 bool isValidLatSetId(LatSetId s) const {
610 return s != detail::kInvalidId && s < latSets.size();
611 }
612 bool maybeZero(ExprId e) const;
613 bool isInvariant(ExprId e) const {
615 }
616 Type inferType(ExprId e, Value src) const;
617
618 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
619 /// The boolean value returned indicates whether the result of the current
620 /// operation being built depends on any value that is loaded from a sparse
621 /// tensor.
622 std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
623 Value v);
624
625 /// Merger data structures.
626 const TensorId outTensor;
627 const TensorId syntheticTensor;
628 const unsigned numTensors;
629 const unsigned numLoops;
630 bool hasSparseOut;
631
632 // Below we use `std::vector` for things which have a priori fixed
633 // sizes, whereas we use `llvm::SmallVector` for things with variable
634 // size. Do beware that these two classes differ in the semantics of
635 // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
636 // does not.
637
638 /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
639 std::vector<std::vector<LevelType>> lvlTypes;
640
641 /// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
642 std::vector<std::vector<std::optional<Level>>> loopToLvl;
643
644 /// Map that converts pair<TensorId, Level> to the corresponding LoopId.
645 std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
646
647 /// Map from a loop to its dependencies if any.
648 /// The dependencies of a loop is a set of (tensor, level) pairs.
649 /// It is currently only set for non-trivial index expressions.
650 /// E.g., A[i+j] => i and j will have dependencies {A0, lt(A0)} to indicate
651 /// that i and j are used in the non-trivial index expression on A0.
652 std::vector<std::vector<std::optional<LvlLTPair>>> loopToUnresolvedLvls;
653
654 /// The inverse map of `loopToUnresolvedLvls`, from tensor level -> dependent
655 /// loops. E.g., for A[2i+j], we have A0 => {(i, 2), (j, 1)}, to indicate that
656 /// A0 uses both {i, j} to compute its indices and the coefficients on the loop
657 /// ids are 2 and 1 respectively.
658 std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
659
660 /// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
661 std::vector<std::pair<TensorId, Level>> loopBounds;
662
663 llvm::SmallVector<TensorExp> tensorExps;
664 llvm::SmallVector<LatPoint> latPoints;
665 llvm::SmallVector<SmallVector<LatPointId>> latSets;
666};
667
668} // namespace sparse_tensor
669} // namespace mlir
670
671#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
false
Parses a map_entries map type from a string format back into its numeric value.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setHasSparseOut(bool s)
Sets whether the output tensor is sparse or not.
Definition Merger.h:469
constexpr unsigned getNumLoops() const
Gets the total number of loops (native loops + filter loops).
Definition Merger.h:355
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
Definition Merger.cpp:315
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
Definition Merger.cpp:338
Level getLoopDependentLevel(TensorLoopId b) const
Definition Merger.h:518
const LatPoint & lat(LatPointId p) const
Definition Merger.h:545
constexpr bool isOutTensor(TensorLoopId b, LoopId i) const
Returns true if b is the ith loop of the output tensor.
Definition Merger.h:358
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
Definition Merger.cpp:577
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
Definition Merger.cpp:680
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
Definition Merger.cpp:522
ArrayRef< LatPointId > set(LatSetId s) const
Definition Merger.h:549
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
Definition Merger.h:541
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operand set to zero,...
Definition Merger.cpp:414
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level contains non-trivial index expressio...
Definition Merger.h:510
std::optional< LoopId > getLoopId(TensorId t, Level lvl) const
Gets the loop identifier for the lvlth level of the tth tensor.
Definition Merger.h:410
void dumpBits(const BitVector &bits) const
Definition Merger.cpp:920
bool hasExprValue(ExprId e) const
Checks whether the given expression has an associated value.
Definition Merger.h:555
void foreachTensorLoopId(LatPointId p, bool simple, ForeachTensorLoopIdCallback callback) const
Definition Merger.h:448
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
Definition Merger.cpp:309
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
Definition Merger.cpp:461
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
Definition Merger.cpp:544
constexpr unsigned getNumTensors() const
Gets the total number of tensors (including the output-tensor and synthetic-tensor).
Definition Merger.h:352
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
Definition Merger.h:501
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets and also set one of the operand to zero: (s0 /\_op s1 (e0 op e1...
Definition Merger.cpp:356
void dumpSet(LatSetId s) const
Definition Merger.cpp:910
void dumpLat(LatPointId p) const
Definition Merger.cpp:899
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
Definition Merger.cpp:380
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
Definition Merger.cpp:247
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition Merger.cpp:940
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
Definition Merger.cpp:329
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
Definition Merger.cpp:277
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
Definition Merger.cpp:270
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition Merger.h:249
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
Definition Merger.cpp:1193
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
Definition Merger.h:426
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
Definition Merger.cpp:401
function_ref< void( TensorLoopId, TensorId, std::optional< Level >, LevelType, bool)> ForeachTensorLoopIdCallback
Definition Merger.h:435
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
Definition Merger.h:442
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on ith loop.
Definition Merger.h:416
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
Definition Merger.cpp:431
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
Definition Merger.h:346
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
Definition Merger.h:472
void dumpExp(ExprId e) const
Print methods (for debugging).
Definition Merger.cpp:799
LevelType getLvlType(TensorLoopId b) const
Gets the level-type of the TensorLoopId.
Definition Merger.h:405
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
Definition Merger.cpp:224
std::optional< Level > getLvl(TensorLoopId b) const
Definition Merger.h:420
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
Definition Merger.cpp:671
void clearExprValue(ExprId e)
Clears the value associated with the expression.
Definition Merger.h:567
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
Definition Merger.cpp:293
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
Definition Merger.cpp:255
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
Definition Merger.h:367
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
Definition Merger.cpp:503
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
Definition Merger.h:348
std::vector< LoopCoeffPair > & getDependentLoops(TensorId t, Level lvl)
Returns the list of loop indices which appear in the non-trivial index expression on t_l,...
Definition Merger.h:488
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
Definition Merger.h:363
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 differ only in dense levels.
Definition Merger.cpp:516
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
Definition Merger.cpp:263
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
Definition Merger.h:243
LevelType getLoopDependentLevelType(TensorLoopId b) const
Definition Merger.h:523
std::pair< TensorId, Level > getLoopDefiningLvl(LoopId i) const
Returns the defining [tid, lvl] for the loop.
Definition Merger.h:494
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on the ith loop.
Definition Merger.h:399
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
Definition Merger.cpp:1618
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
Definition Merger.h:255
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
Definition Merger.h:370
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
Definition Merger.h:559
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has a dependent slice.
Definition Merger.h:481
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition Merger.h:30
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of a loop id and its coefficient.
Definition Merger.h:64
unsigned LatPointId
LatPoint identifiers.
Definition Merger.h:52
unsigned ExprId
TensorExp identifiers.
Definition Merger.h:48
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
Definition Merger.h:60
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition Merger.h:35
uint64_t Level
The type of level identifiers and level-ranks.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition Merger.h:44
bool isValidLT(LevelType lt)
Definition Enums.h:433
unsigned LoopId
Loop identifiers.
Definition Merger.h:38
unsigned LatSetId
LatSet identifiers.
Definition Merger.h:57
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
LatPoint(const BitVector &bits, ExprId e)
Construct a lattice point from the given set of TensorLoopIds.
Definition Merger.h:207
ExprId exp
Identifier of the tensor expression.
Definition Merger.h:218
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
Definition Merger.h:210
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
Definition Merger.h:215
LatPoint(unsigned size, ExprId e)
Construct a lattice point with the empty set of TensorLoopIds.
Definition Merger.h:204
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238
Child subexpressions for non-leaf expressions.
Definition Merger.h:71
Tensor expression. Represents an MLIR expression in tensor index notation.
Definition Merger.h:67
LoopId loop
kLoopVar expressions simply have a loop identifier.
Definition Merger.h:96
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition Merger.h:105
Kind
Tensor expression kind.
Definition Merger.h:129
Children children
All other expressions hold the ExprIds of their children.
Definition Merger.h:99
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
Definition Merger.h:118
TensorId tensor
kTensor expressions simply have a tensor identifier.
Definition Merger.h:93
Kind kind
Tensor expression kind.
Definition Merger.h:89
Operation * op
Code blocks used by semirings.
Definition Merger.h:114
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Definition Merger.cpp:106