46 const LoopId i = cast<AffineDimExpr>(a).getPosition();
55 auto binOp = cast<AffineBinaryOpExpr>(a);
60 assert(isa<AffineConstantExpr>(a));
84 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
88 return findAffine(merger, tid, lvl, binOp.getLHS(), lt,
false) &&
89 findAffine(merger, tid, lvl, binOp.getRHS(), lt,
false);
115 int64_t coefficient = 1) {
119 if (coefficient <= 0)
130 assert(coefficient == 1);
162 if (isa<AffineConstantExpr>(a))
163 llvm_unreachable(
"Not yet implemented");
165 auto binOp = cast<AffineBinaryOpExpr>(a);
166 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
167 if (isa<AffineConstantExpr>(rhs))
170 assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
171 int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
172 return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
175 auto binOp = cast<AffineBinaryOpExpr>(a);
176 return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt,
true) &&
177 findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt,
true);
197 const auto rtp = dyn_cast<RankedTensorType>(tensor.
getType());
204 assert(
static_cast<Dimension>(exprs.size()) == lvlRank &&
205 "AffineMap does not have dimension-rank many results");
207 for (
Level l = 0; l < lvlRank; l++) {
226 OpOperand *out = op.getDpsInitOperand(0);
243 bool annotated =
false;
246 const auto map = env.
op().getMatchingIndexingMap(&t);
250 const Level lvlRank = map.getNumResults();
251 assert(!enc || lvlRank == enc.getLvlRank());
252 assert(
static_cast<Level>(env.
op().getRank(&t)) == lvlRank);
261 for (
Level l = 0; l < lvlRank; l++) {
264 if (idxReducBased && needIdxReduc) {
282 linalg::GenericOp op = env.
op();
284 assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
287 llvm::cast<linalg::LinalgOp>(op.getOperation())
288 .createLoopRanges(builder, loc);
304 OpOperand *lhs = op.getDpsInitOperand(0);
305 assert(lhs->
get() == tensor);
314 bool isInit = op.isInitTensor(lhs);
319 linalg::FillOp::create(builder, loc,
ValueRange{zero},
325 assert(l < loopRange.size());
332 const auto map = env.
op().getMatchingIndexingMap(t);
334 const Level lvlRank = stt.getLvlRank();
335 assert(
static_cast<Level>(map.getNumResults()) == lvlRank);
336 const AffineExpr a = map.getResult(lvlRank - 1);
347 const auto map = env.
op().getMatchingIndexingMap(t);
349 if (stt.hasEncoding()) {
352 assert(!pos.empty());
359 const Level lvlRank = stt.getLvlRank();
360 assert(
static_cast<Level>(map.getNumResults()) == lvlRank);
361 for (
Level l = 0; l < lvlRank; l++) {
362 const auto lvlExpr = map.getResult(l);
364 args.push_back(lvlCrd);
373 linalg::GenericOp op = env.
op();
382 return memref::LoadOp::create(builder, loc, env.
getExpandValues(), index);
388 linalg::GenericOp op = env.
op();
398 Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
399 Value valAtIndex = memref::LoadOp::create(builder, loc, values, index);
400 return arith::SelectOp::create(builder, loc, isFilled, valAtIndex, identity);
405 scf::IfOp condInsert =
406 scf::IfOp::create(builder, loc, sparseOut.
getType(), cond,
true);
409 Value res = tensor::InsertOp::create(builder, loc, v, sparseOut, ivs);
410 scf::YieldOp::create(builder, loc, res);
413 scf::YieldOp::create(builder, loc, sparseOut);
416 return condInsert.getResult(0);
422 linalg::GenericOp op = env.
op();
426 const LoopId numLoops = op.getRank(t);
450 sparseOut = tensor::InsertOp::create(builder, loc, rhs, chain, ivs);
471 Value isFilled = memref::LoadOp::create(builder, loc, filled, index);
472 Value cond = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
474 scf::IfOp ifOp = scf::IfOp::create(builder, loc, builder.
getIndexType(), cond,
478 memref::StoreOp::create(builder, loc, tval, filled, index);
479 memref::StoreOp::create(builder, loc, index, added, count);
481 Value add = arith::AddIOp::create(builder, loc, count, one);
482 scf::YieldOp::create(builder, loc, add);
485 scf::YieldOp::create(builder, loc, count);
489 memref::StoreOp::create(builder, loc, rhs, values, index);
499 linalg::GenericOp op = env.
op();
504 if (
auto explVal = stt.getExplicitVal())
516 if (llvm::isa<TensorType>(ptr.
getType())) {
519 return ExtractValOp::create(builder, loc, ptr,
520 llvm::getSingleElement(args));
522 return memref::LoadOp::create(builder, loc, ptr, args);
543 linalg::GenericOp op = env.
op();
549 memref::StoreOp::create(builder, loc, rhs, ptr, args);
560 scf::IfOp::create(builder, loc, chain.
getType(), rhs,
true);
569 scf::YieldOp::create(builder, op.getLoc(), mchain);
572 scf::YieldOp::create(builder, loc, chain);
590 if (
auto arg = dyn_cast<BlockArgument>(e)) {
594 linalg::GenericOp op = env.
op();
595 if (arg.getOwner()->getParentOp() == op) {
601 return memref::LoadOp::create(rewriter, op.getLoc(), ptr, args);
605 if (
auto indexOp = dyn_cast<linalg::IndexOp>(def))
608 if (def->getBlock() == block) {
610 for (
unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
613 i,
relinkBranch(env, rewriter, block, def->getOperand(i)));
626 linalg::GenericOp op = env.
op();
682 LoopId curr,
bool isStart) {
687 linalg::GenericOp op = env.
op();
689 const auto map = op.getMatchingIndexingMap(&t);
691 const Level lvlRank = stt.getLvlRank();
692 assert(
static_cast<Level>(map.getNumResults()) == lvlRank);
693 bool isCurrentLoop = curr == 0;
694 for (
Level l = 0; l < lvlRank; l++) {
706 OpOperand *lhs = op.getDpsInitOperand(0);
753 linalg::GenericOp op = env.
op();
754 OpOperand *lhs = op.getDpsInitOperand(0);
766 auto dynShape = {ShapedType::kDynamic};
767 Type etp = cast<ShapedType>(tensor.
getType()).getElementType();
773 ExpandOp::create(builder, loc,
TypeRange({t1, t2, t3, t4}), tensor);
774 assert(r.getNumResults() == 4);
775 env.
startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
779 for (
LoopId i = 0; i < curr; i++)
786 Value compress = CompressOp::create(builder, loc, values, filled, added,
787 count, chain, indices);
808 return isOuter && !isSparse;
816 llvm_unreachable(
"unexpected parallelization strategy");
823 linalg::GenericOp op = env.
op();
824 auto iteratorTypes = op.getIteratorTypesArray();
825 bool isSparse = llvm::any_of(tidLvls, [curr, &env](
TensorLevel tidLvl) {
840 unsigned numCases,
bool tryParallel,
844 return env.
emitter().enterCoIterationOverTensorsAtLvls(
845 builder, env.
op().getLoc(), tidLvls, numCases, reduc, tryParallel,
855 unsigned numCases,
bool needsUniv,
858 return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
868 while (
auto ifOp = dyn_cast_or_null<scf::IfOp>(
893 assert(y == yields.size());
894 scf::YieldOp::create(builder, loc, yields);
907 assert(allCase == curCase || env.
merger().
latGT(allCase, curCase));
915 if (curCaseBits.test(set))
937 assert(lvl.has_value() &&
isUndefLT(lt));
939 lt = stt.getLvlType(*lvl);
944 assert(lvl.has_value());
947 clause = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
954 cond ? arith::AndIOp::create(builder, loc, cond, clause) : clause;
965 scf::IfOp ifOp = scf::IfOp::create(builder, loc, types, cond,
true);
980 operands.push_back(
constantI1(builder, env.
op().getLoc(),
true));
992 if (!operands.empty())
993 scf::YieldOp::create(builder, env.
op().getLoc(), operands);
1004 const BitVector &simple = env.
lat(li).
simple;
1006 const std::optional<Level> outLvl = env.
merger().
getLvl(outTid, curr);
1008 unsigned numloopCond = 0;
1009 bool hasNonUnique =
false;
1038 hasNonUnique = !
isUniqueLT(lt) || hasNonUnique;
1045 linalg::GenericOp op = env.
op();
1046 if (tid >= op.getNumDpsInputs())
1049 OpOperand *operand = &op->getOpOperand(tid);
1052 if (!stt.hasEncoding())
1056 op.getMatchingIndexingMap(operand).getResults();
1057 const Level lvlRank = stt.getLvlRank();
1058 assert(affines.size() ==
static_cast<size_t>(lvlRank));
1059 for (
Level l = 0; l < lvlRank; l++) {
1068 if (!isa<AffineConstantExpr>(exp)) {
1069 bool isCurrentLoop =
false;
1091 if (stt.hasEncoding() && stt.isAllDense())
1095 if (numloopCond == 0) {
1106 return numloopCond == 1 &&
1128 if (llvm::is_contained(tidLvls, tl))
1130 tidLvls.emplace_back(tl);
1149 linalg::GenericOp op = env.
op();
1150 assert(tid < op.getNumDpsInputs());
1151 OpOperand *input = op.getDpsInputOperands()[tid];
1152 const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
1157 const Level lvlRank = enc.getLvlRank();
1158 assert(lvlExprs.size() ==
static_cast<size_t>(lvlRank));
1159 for (
Level l = startLvl; l < lvlRank; l++) {
1161 if (enc.getLvlType(l).hasDenseSemantic() &&
1162 isa<AffineConstantExpr>(lvlExpr))
1177 for (
TensorId tid = 0, e = env.
op().getNumDpsInputs(); tid < e; tid++)
1189 affineTidLvls.emplace_back(tl, exp);
1191 tidLvls.emplace_back(tl);
1212 Operation *loop =
genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
1214 for (
auto [tidLvl, exp] : affineTidLvls) {
1222 llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1229 return std::make_pair(loop, isSingleCond);
1234 LatPointId li,
bool needsUniv,
bool isSingleCond) {
1240 }
else if (
auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1247 env.
emitter().exitCurrentLoop(rewriter, env.
op().getLoc(), reduc);
1248 return std::nullopt;
1283 bool needsUniv =
startLoopSeq(env, rewriter, exp, curr, lts);
1287 const unsigned lsize = env.
set(lts).size();
1291 auto [loop, isSingleCond] =
1292 startLoop(env, rewriter, curr, li, lsize, needsUniv);
1293 assert(isSingleCond == llvm::isa<IterateOp>(loop));
1297 for (
unsigned j = 0;
j < lsize;
j++) {
1301 if (!isSingleCond) {
1304 genStmt(env, rewriter, ej, curr + 1);
1306 assert(reduc.empty() &&
"Not Implemented");
1307 sparse_tensor::YieldOp::create(rewriter, env.
op().getLoc());
1308 return std::nullopt;
1312 genStmt(env, rewriter, ej, curr + 1);
1316 needsUniv =
endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1319 for (
unsigned i = 0; i < lsize; i++) {
1322 auto [loop, isSingleCond] =
1323 startLoop(env, rewriter, curr, li, lsize, needsUniv);
1334 for (
unsigned j = 0;
j < lsize;
j++) {
1339 if (!isSingleCond) {
1340 scf::IfOp ifOp =
genIf(env, rewriter, curr, lj);
1341 genStmt(env, rewriter, ej, curr + 1);
1342 endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1344 genStmt(env, rewriter, ej, curr + 1);
1350 needsUniv =
endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1361 linalg::GenericOp op = env.
op();
1362 OpOperand *lhs = op.getDpsInitOperand(0);
1369 bool hasInserts =
false;
1395 LogicalResult matchAndRewrite(linalg::GenericOp op,
1398 if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
1406 if (!op->hasAttr(
"sorted")) {
1408 op,
"Loops not yet scheduled, try run --sparse-reinterpret-map "
1409 "before sparsification.");
1416 const unsigned numTensors = op->getNumOperands();
1417 const unsigned numLoops = op.getNumLoops();
1425 Level maxLvlRank = 0;
1426 for (
auto operand : op.getOperands()) {
1427 if (
auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1443 if (op.getNumReductionLoops() > 0) {
1445 assert(isa<linalg::YieldOp>(yield));
1447 if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1448 !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1449 !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1450 !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1451 !isa<ReduceOp>(redop)) {
union mlir::linalg::@1224::ArityGroupAndKind::Kind kind
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, Value tensor)
Gets the total number of compound affine expressions in the getMatchingIndexingMap for the given tens...
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, unsigned numCases, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, OpOperand *t)
Generates insertion code to implement dynamic tensor load for reduction.
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop)
Returns true iff affine expression is invariant.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, AffineExpr a, LevelType lt, bool isSubExp=false, int64_t coefficient=1)
Helper method to inspect affine expressions for index variable reduction based codegen.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, LatPointId p)
Generates a single if-statement within a while-loop.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl< Value > &args)
Generates subscript for load/store on a dense or sparse tensor.
static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond, Value sparseOut, ValueRange ivs, Value v)
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr, bool isStart)
Generates an expanded access pattern in innermost dimension.
static void genConstantDenseAddressFromLevel(CodegenEnv &env, OpBuilder &builder, TensorId tid, Level startLvl)
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId curr, LatSetId lts)
Starts a loop sequence at given level.
static std::pair< Operation *, bool > startLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, LatPointId li, unsigned numCases, bool needsUniv)
Starts a single loop in current sequence.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId curr, bool isStart)
Hoists loop invariant tensor loads for which indices have been exhausted.
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned at)
Ends a loop sequence at given level.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse)
Returns parallelization strategy.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, LevelType lt, bool setLvlFormat=true)
Helper method to inspect affine expressions.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased)
Helper method to inspect sparse encodings in the tensor types.
static bool getAllTidLvlsInLatPoints(CodegenEnv &env, LatPointId li, LoopId curr, llvm::function_ref< void(TensorLevel, AffineExpr)> callback)
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, OpOperand *t)
Generates insertion code to implement dynamic tensor load.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op)
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, Value redInput, Value cntInput, Value insInput, Value validIns)
Generates end of true branch of if-statement within a while-loop.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, LoopId curr)
Recursively generates code while computing iteration lattices in order to manage the complexity of im...
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, Value rhs)
Generates insertion code to implement dynamic tensor store.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, Value rhs)
Generates a store on a dense or sparse tensor.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, Value e)
Semi-ring branches are simply inlined by the sparsifier.
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
static void genResult(CodegenEnv &env, RewriterBase &rewriter)
Converts the result computed by the sparse kernel into the required form.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr, ArrayRef< TensorLevel > tidLvls)
Whether or not the current loop being generated should be parallelized (if possible) according to the c...
static bool translateBitsToTidLvlPairs(CodegenEnv &env, LatPointId li, LoopId curr, SmallVectorImpl< TensorLevel > &tidLvls, SmallVectorImpl< std::pair< TensorLevel, AffineExpr >> &affineTidLvls)
Returns true if the lattice bit can be iterated by a for loop.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e)
Recursively generates tensor expression.
static void genInitConstantDenseAddress(CodegenEnv &env, RewriterBase &rewriter)
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp)
Generates a load on a dense or sparse tensor.
static Value genInvariantValue(CodegenEnv &env, ExprId exp)
Generates an invariant value.
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder, unsigned caseIdx, LatPointId allCase, LatPointId curCase, MutableArrayRef< Value > reduc)
Generates a case region in the coiterate operation.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, LatPointId li, bool needsUniv, bool isSingleCond)
Ends a single loop in current sequence. Returns new values for needsUniv.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, bool needsUniv)
Generates the induction structure for a while-loop.
static Value genIndex(CodegenEnv &env, OpOperand *t)
Generates index for load/store on sparse tensor.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
The code generation environment class aggregates a number of data structures that are needed during t...
void startReduc(ExprId exp, Value val)
void updateValidLexInsert(Value val)
const SparsificationOptions & options() const
Value getInsertionChain() const
std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)
Generates loop boundary statements (entering/exiting loops).
ArrayRef< LatPointId > set(LatSetId s) const
bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const
bool isCustomReduc() const
unsigned getCurrentDepth() const
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tl) const
Value getExpandValues() const
TensorLevel makeTensorLevel(TensorId t, Level l) const
const LatPoint & lat(LatPointId l) const
constexpr TensorId makeTensorId(unsigned t) const
void startExpand(Value values, Value filled, Value added, Value count)
bool hasSparseOutput() const
unsigned getLoopNum() const
void updateInsertionChain(Value chain)
bool generatingSparseIterator() const
Value getExpandCount() const
void startCustomReduc(ExprId exp)
linalg::GenericOp op() const
Value getLoopVar(LoopId i) const
Returns the induction-variable for the given loop.
Value getExpandFilled() const
LogicalResult initTensorExp()
void startEmit(SparseEmitStrategy emitStrategy)
auto unpackTensorLevelRange(ContainerTy &&c) const
Value getExpandAdded() const
const TensorExp & exp(ExprId e) const
void updateExpandCount(Value count)
void updateReduc(Value val)
Value getValidLexInsert() const
bool isSparseOutput(OpOperand *o) const
void startValidLexInsert(Value val)
constexpr LoopId makeLoopId(unsigned i) const
Value getCustomRedId() const
LevelType lt(TensorId t, LoopId i) const
bool isValidLexInsert() const
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
I64BitSet & set(unsigned i)
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
void locateLvlAtAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr)
Emits the address for a dense level based on the value evaluated by the provided affine expression.
const std::vector< Value > & getValBuffer() const
void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls)
Enters a new loop sequence; the loops within the same sequence start from the break points of previo...
Value genAffine(OpBuilder &builder, Location loc, AffineExpr a)
Generates code to compute an affine expression whose variables are LoopIds (i.e., cast<AffineDimExpr>...
Region * enterCurrentCoIterationCase(OpBuilder &builder, Location loc, I64BitSet caseBit, unsigned caseIdx, MutableArrayRef< Value > reduc)
Value getLoopIV(LoopId n) const
Gets loop induction variable for the given loop.
SmallVector< Value > getValPosits(TensorId tid) const
Getters.
auto getLoopIVsRange() const
Get the range of values for all induction variables.
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
void exitCurrentLoopSeq(OpBuilder &builder, Location loc)
Exits the current loop sequence, this will reset universal index to 0.
Value getCoord(TensorId tid, Level lvl) const
A class to handle all iteration lattice operations.
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on ith loop.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
void clearExprValue(ExprId e)
Clears the value associated with the expression.
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
const LatPoint & lat(LatPointId p) const
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has dependent slice.
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllDense() const
Returns true for tensors where every level is dense.
Level getLvlRank() const
Returns the level-rank.
LevelType getLvlType(Level l) const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
bool isUniqueLT(LevelType lt)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
unsigned LatSetId
LatSet identifiers.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
uint64_t Level
The type of level identifiers and level-ranks.
unsigned LoopId
Loop identifiers.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnySparseType(TypeRange types)
Returns true iff the type range has any sparse tensor type.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
bool isUndefLT(LevelType lt)
std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)
bool isDenseLT(LevelType lt)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
unsigned ExprId
TensorExp identifiers.
unsigned LatPointId
LatPoint identifiers.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options for the Sparsification pass.
SparseEmitStrategy sparseEmitStrategy
SparseParallelizationStrategy parallelizationStrategy
ExprId exp
Identifier of the tensor expression.
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool hasSparseSemantic() const
Check if the LevelType is considered to be sparse.
constexpr bool hasDenseSemantic() const
Check if the LevelType is considered to be dense-like.
Tensor expression. Represents an MLIR expression in tensor index notation.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Children children
All other expressions hold the ExprIds of their children.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.