13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
20 #include "llvm/ADT/BitVector.h"
25 namespace sparse_tensor {
235 Merger(
unsigned numInputOutputTensors,
unsigned numLoops,
236 unsigned maxLvlRank);
244 assert(isValidTensorId(t));
250 assert(isValidLoopId(i));
256 assert(isValidTensorId(t) && isValidLoopId(i));
257 return numTensors * i + t;
371 const auto &expr =
exp(e);
400 assert(isValidTensorId(t) && isValidLoopId(i));
401 return lvlTypes[t][i];
411 assert(isValidLevel(t, lvl));
412 return lvlToLoop[t][lvl];
417 assert(isValidTensorId(t) && isValidLoopId(i));
418 return loopToLvl[t][i];
427 assert(isValidLevel(t, lvl) && isValidLoopId(i) &&
isValidLT(lt));
429 loopToLvl[t][i] = lvl;
430 lvlToLoop[t][lvl] = i;
432 loopBounds[i] = std::make_pair(t, lvl);
450 const auto &point =
lat(p);
451 const auto &bits = simple ? point.simple : point.bits;
454 const auto optLvl =
getLvl(b);
458 assert(!optLvl.has_value());
463 callback(b, t, optLvl, lvlTp,
false);
474 assert(isValidLoopId(i) && isValidLevel(t, lvl));
475 assert(!loopToUnresolvedLvls[i][t].has_value());
476 loopToUnresolvedLvls[i][t] = std::make_pair(lvl, lt);
477 levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
482 assert(isValidTensorId(t) && isValidLoopId(i));
483 return loopToUnresolvedLvls[i][t].has_value();
489 assert(isValidLevel(t, lvl));
490 return levelToDependentLoop[t][lvl];
495 assert(isValidLoopId(i));
496 return loopBounds[i];
504 assert(isValidTensorId(t) && isValidLoopId(i));
505 return loopToUnresolvedLvls[i][t].has_value();
513 return lt.hasSparseSemantic();
520 return loopToUnresolvedLvls[
loop(b)][
tensor(b)]->first;
525 return loopToUnresolvedLvls[
loop(b)][
tensor(b)]->second;
542 assert(isValidExprId(e));
543 return tensorExps[e];
546 assert(isValidLatPointId(p));
550 assert(isValidLatSetId(s));
560 assert(!
exp(e).val &&
"Expression already has an associated value");
561 assert(v &&
"Trying to assign an undefined value");
562 tensorExps[e].val = v;
568 assert(
exp(e).val &&
"Expression does not have an associated value");
569 tensorExps[e].val =
Value();
577 void dumpBits(
const BitVector &bits)
const;
595 constexpr
bool isValidTensorId(
TensorId t)
const {
return t < numTensors; }
596 constexpr
bool isValidLoopId(
LoopId i)
const {
600 assert(levelToDependentLoop[t].size() == lvlToLoop[t].size());
601 return isValidTensorId(t) && lvl < lvlToLoop[t].size();
603 bool isValidExprId(
ExprId e)
const {
609 bool isValidLatSetId(
LatSetId s)
const {
612 bool maybeZero(
ExprId e)
const;
613 bool isInvariant(
ExprId e)
const {
622 std::pair<std::optional<ExprId>,
bool> buildTensorExp(linalg::GenericOp op,
628 const unsigned numTensors;
629 const unsigned numLoops;
639 std::vector<std::vector<LevelType>> lvlTypes;
642 std::vector<std::vector<std::optional<Level>>> loopToLvl;
645 std::vector<std::vector<std::optional<LoopId>>> lvlToLoop;
652 std::vector<std::vector<std::optional<LvlLTPair>>> loopToUnresolvedLvls;
658 std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
661 std::vector<std::pair<TensorId, Level>> loopBounds;
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A class to handle all iteration lattice operations.
void setHasSparseOut(bool s)
Sets whether the output tensor is sparse or not.
constexpr unsigned getNumLoops() const
Gets the total number of loops (native loops + filter loops).
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
Level getLoopDependentLevel(TensorLoopId b) const
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on the ith loop.
constexpr bool isOutTensor(TensorLoopId b, LoopId i) const
Returns true if b is the ith loop of the output tensor.
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operands set to zero,...
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expressio...
void dumpBits(const BitVector &bits) const
bool hasExprValue(ExprId e) const
Checks whether the given expression has an associated value.
void foreachTensorLoopId(LatPointId p, bool simple, ForeachTensorLoopIdCallback callback) const
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
std::optional< LoopId > getLoopId(TensorId t, Level lvl) const
Gets the loop identifier for the lvlth level of the tth tensor.
std::pair< TensorId, Level > getLoopDefiningLvl(LoopId i) const
Returns the defining [tid, lvl] for the loop.
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
constexpr unsigned getNumTensors() const
Gets the total number of tensors (including the output-tensor and synthetic-tensor).
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets that also sets one of the operands to zero: (s0 /\_op s1 (e0 op e1...
void dumpSet(LatSetId s) const
void dumpLat(LatPointId p) const
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
std::optional< Level > getLvl(TensorLoopId b) const
ArrayRef< LatPointId > set(LatSetId s) const
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
void dumpExp(ExprId e) const
Print methods (for debugging).
LevelType getLvlType(TensorLoopId b) const
Gets the level-type of the TensorLoopId.
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
void clearExprValue(ExprId e)
Clears the value associated with the expression.
std::vector< LoopCoeffPair > & getDependentLoops(TensorId t, Level lvl)
Returns the list of loop indices which appear in the non-trivial index expression on t_l,...
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
const LatPoint & lat(LatPointId p) const
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
LevelType getLoopDependentLevelType(TensorLoopId b) const
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has dependent slice.
@ Type
An inlay hint that is for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
unsigned LatSetId
LatSet identifiers.
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
uint64_t Level
The type of level identifiers and level-ranks.
unsigned LoopId
Loop identifiers.
bool isValidLT(LevelType lt)
unsigned ExprId
TensorExp identifiers.
unsigned LatPointId
LatPoint identifiers.
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
LatPoint(const BitVector &bits, ExprId e)
Construct a lattice point from the given set of TensorLoopIds.
ExprId exp
Identifier of the tensor expression.
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
LatPoint(unsigned size, ExprId e)
Construct a lattice point with the empty set of TensorLoopIds.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Child subexpressions for non-leaf expressions.
Tensor expression. Represents an MLIR expression in tensor index notation.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Kind
Tensor expression kind.
Children children
All other expressions hold the ExprIds of their children.
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
Operation * op
Code blocks used by semirings.
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.