13#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
14#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
41 unsigned numTensors,
unsigned numLoops,
unsigned maxRank);
50 linalg::GenericOp
op()
const {
return linalgOp; }
53 return sparseOptions.sparseEmitStrategy ==
63 std::optional<Operation *>
73 return latticeMerger.makeTensorId(t);
76 return latticeMerger.makeLoopId(i);
79 return latticeMerger.makeTensorLoopId(t, i);
85 return latticeMerger.getLvlType(t, i);
89 unsigned getLoopNum()
const {
return latticeMerger.getNumLoops(); }
98 assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
99 loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
100 loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
101 loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
102 return loopEmitter.makeTensorLevel(t, l);
108 return loopEmitter.unpackTensorLevel(tl);
110 template <
class ContainerTy>
112 return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
140 bool isExpand()
const {
return expValues !=
nullptr; }
171 linalg::GenericOp linalgOp;
204 Value redValidLexInsert;
This class represents an operand of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void startReduc(ExprId exp, Value val)
void updateValidLexInsert(Value val)
Value getInsertionChain() const
std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)
Generates loop boundary statements (entering/exiting loops).
bool isAdmissibleTensorExp(ExprId e)
Whether the tensor expression is admissible for codegen.
bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const
bool isCustomReduc() const
CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, unsigned numTensors, unsigned numLoops, unsigned maxRank)
Constructs a code generation environment which can be passed around during sparsification for bookkeeping.
ArrayRef< LatPointId > set(LatSetId s) const
unsigned getCurrentDepth() const
Value getExpandValues() const
TensorLevel makeTensorLevel(TensorId t, Level l) const
constexpr TensorId makeTensorId(unsigned t) const
LevelType lt(TensorLoopId b) const
void startExpand(Value values, Value filled, Value added, Value count)
bool hasSparseOutput() const
unsigned getLoopNum() const
void updateInsertionChain(Value chain)
bool generatingSparseIterator() const
Value getExpandCount() const
void startCustomReduc(ExprId exp)
TensorLevel makeTensorLevel(std::pair< TensorId, Level > tlPair) const
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
linalg::GenericOp op() const
Value getLoopVar(LoopId i) const
Returns the induction-variable for the given loop.
Value getExpandFilled() const
LogicalResult initTensorExp()
void startEmit(SparseEmitStrategy emitStrategy)
auto unpackTensorLevelRange(ContainerTy &&c) const
Value getExpandAdded() const
void updateExpandCount(Value count)
void updateReduc(Value val)
Value getValidLexInsert() const
bool isSparseOutput(OpOperand *o) const
void startValidLexInsert(Value val)
constexpr LoopId makeLoopId(unsigned i) const
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tl) const
Value getCustomRedId() const
const TensorExp & exp(ExprId e) const
const SparsificationOptions & options() const
LevelType lt(TensorId t, LoopId i) const
bool isValidLexInsert() const
const LatPoint & lat(LatPointId l) const
A class to handle all iteration lattice operations.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
unsigned LatPointId
LatPoint identifiers.
unsigned ExprId
TensorExp identifiers.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::buildTensorExp.
uint64_t Level
The type of level identifiers and level-ranks.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
unsigned LoopId
Loop identifiers.
unsigned LatSetId
LatSet identifiers.
Include the generated interface declarations.
SparseEmitStrategy
Defines the strategy used to emit sparse iteration code during sparsification.
llvm::function_ref< Fn > function_ref
Options for the Sparsification pass.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Tensor expression. Represents an MLIR expression in tensor index notation.