  std::sort(target.begin(), target.end(),
            [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
              assert(std::addressof(l) == std::addressof(r) || l != r);
              return l.first < r.first;
            });
CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}
  assert(insChain == nullptr && "must only start emitting once");
  ...
  insChain = sparseOut->get();
  for (OpOperand &t : linalgOp->getOpOperands()) {
    tensors.push_back(t.get());
    const TensorId tid = makeTensorId(t.getOperandNumber());
    const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
    const auto enc = getSparseTensorEncoding(t.get().getType());
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; lvl++)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }
  loopEmitter.initialize(
      tensors,
      StringAttr::get(linalgOp.getContext(),
                      linalg::GenericOp::getOperationName()),
      ...);
  SmallVector<Value> params;
  if (isReduc()) {
    params.push_back(redVal);
    if (isValidLexInsert())
      params.push_back(redValidLexInsert);
  }
  if (isExpand())
    params.push_back(expCount);
  if (insChain != nullptr)
    params.push_back(insChain);
  auto r = callback(params); // the callback may update the parameters in place
  // Read the (possibly updated) values back in the same order they were pushed.
  unsigned i = 0;
  ...
  if (insChain != nullptr)
    updateInsertionChain(params[i]);
  return r;
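A minimal sketch of how a caller might supply the genLoopBoundary callback, assuming an scf.for loop and illustrative locals (env, builder, loc, lo, hi, step); this is not the pass's actual code, and the loop-body and scf.yield construction are omitted:

// Hypothetical caller: thread the loop-carried values (reduction value,
// expansion count, insertion chain) across an scf.for boundary.
env.genLoopBoundary([&](MutableArrayRef<Value> params)
                        -> std::optional<Operation *> {
  // Pass the packed values as iter_args so they survive the loop boundary.
  auto forOp = builder.create<scf::ForOp>(loc, lo, hi, step, params);
  // Hand the region-carried block arguments back; genLoopBoundary then
  // stores them into redVal/expCount/insChain in the push order above.
  for (unsigned i = 0, e = params.size(); i < e; i++)
    params[i] = forOp.getRegionIterArgs()[i];
  return forOp.getOperation();
});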
  for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
    if (it == utils::IteratorType::reduction) {
      // Reductions that negate the output tensor are rejected here
      // (see hasNegateOnOut).
      ...
      break;
    }
  }
  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
  ...
  // Find the outermost all-parallel loop nest; e.g., for a matmul with
  // iterator types (parallel, parallel, reduction) the scan stops at the
  // reduction and outerParNest becomes 2.
  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
  for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
    if (linalg::isReductionIterator(iteratorTypes[i]))
      break;
    outerParNest++;
  }
  assert(static_cast<int64_t>(outerParNest) >=
         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
  // updateInsertionChain:
  assert(sparseOut != nullptr && insChain != nullptr);
  ...
  // atExpandLevel:
  return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
         outerParNest == n;
  ...
  // startExpand:
  assert(sparseOut != nullptr && expValues == nullptr);
  ...
  // updateExpandCount:
  assert(sparseOut != nullptr && expValues != nullptr);
  ...
  // endExpand:
  assert(sparseOut != nullptr && expValues != nullptr);
  expValues = expFilled = expAdded = expCount = Value();
  // startValidLexInsert:
  redValidLexInsert = val;
  ...
  // updateValidLexInsert:
  assert(redValidLexInsert && isReduc() && val);
  redValidLexInsert = val;
  ...
  // endValidLexInsert:
  redValidLexInsert = Value();
  return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
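Read together, these excerpts outline the lifecycle of a CodegenEnv inside the sparsification pass. A rough, hypothetical driver-side sketch (local names such as op, options, numTensors, numLoops, maxLvlRank, and rootExp are placeholders, not the pass's actual variables):

// Bookkeeping environment for one linalg.generic being sparsified.
CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
// Build the tensor expression from the operation's body.
if (failed(env.initTensorExp()))
  return failure();
// rootExp stands for the root ExprId of the kernel (obtained elsewhere);
// reject kernels whose sparse output cannot be emitted admissibly.
if (!env.isAdmissibleTensorExp(rootExp))
  return failure();
// Initialize the loop emitter and lattice merger before emitting loops.
env.startEmit(SparseEmitStrategy::kFunctional);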
static bool isMaterializing(Value val)
Returns true if tensor materializes uninitialized into the computation.
static void sortDependentLoops(std::vector< LoopCoeffPair > &target)
Sorts the dependent loops such that they are ordered in the same sequence in which loops will be generated.
IRValueT get() const
Return the current value being used by this operand.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void startReduc(ExprId exp, Value val)
void updateValidLexInsert(Value val)
std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)
Generates loop boundary statements (entering/exiting loops).
bool isAdmissibleTensorExp(ExprId e)
Whether the tensor expression is admissible for codegen.
bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const
bool isCustomReduc() const
CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, unsigned numTensors, unsigned numLoops, unsigned maxRank)
Constructs a code generation environment which can be passed around during sparsification for bookkeeping.
constexpr TensorId makeTensorId(unsigned t) const
void startExpand(Value values, Value filled, Value added, Value count)
unsigned getLoopNum() const
void updateInsertionChain(Value chain)
void startCustomReduc(ExprId exp)
linalg::GenericOp op() const
Value getLoopVar(LoopId i) const
Returns the induction-variable for the given loop.
LogicalResult initTensorExp()
void startEmit(SparseEmitStrategy emitStrategy)
const TensorExp & exp(ExprId e) const
void updateExpandCount(Value count)
void updateReduc(Value val)
void startValidLexInsert(Value val)
Value getCustomRedId() const
bool isValidLexInsert() const
void initialize(ValueRange tensors, StringAttr loopTag=nullptr, bool hasOutput=false, bool isSparseOut=false, unsigned numLoops=0, DependentLvlGetter getter=nullptr, SparseEmitStrategy emitStrategy=SparseEmitStrategy::kFunctional)
Takes an array of input tensors, which the generated loops will iterate over.
Value getLoopIV(LoopId n) const
Gets loop induction variable for the given loop.
void setHasSparseOut(bool s)
Sets whether the output tensor is sparse or not.
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
void clearExprValue(ExprId e)
Clears the value associated with the expression.
std::vector< LoopCoeffPair > & getDependentLoops(TensorId t, Level lvl)
Returns the list of loop indices which appear in the non-trivial index expression on t_l,...
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
bool isAllDense() const
Returns true for tensors where every level is dense.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
uint64_t Level
The type of level identifiers and level-ranks.
unsigned LoopId
Loop identifiers.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
unsigned ExprId
TensorExp identifiers.
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
SparseEmitStrategy
Defines a scope for reinterpret map pass.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Options for the Sparsification pass.