36 #include "llvm/ADT/SmallBitVector.h"
56 const LoopId i = cast<AffineDimExpr>(a).getPosition();
67 auto binOp = cast<AffineBinaryOpExpr>(a);
72 assert(isa<AffineConstantExpr>(a));
97 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
102 return findAffine(merger, tid, lvl, binOp.getLHS(), lt,
false) &&
103 findAffine(merger, tid, lvl, binOp.getRHS(), lt,
false);
129 int64_t coefficient = 1) {
133 if (coefficient <= 0)
144 assert(coefficient == 1);
176 if (isa<AffineConstantExpr>(a))
177 llvm_unreachable(
"Not yet implemented");
179 auto binOp = cast<AffineBinaryOpExpr>(a);
180 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
181 if (isa<AffineConstantExpr>(rhs))
184 assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
185 int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
186 return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
189 auto binOp = cast<AffineBinaryOpExpr>(a);
190 return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt,
true) &&
191 findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt,
true);
211 const auto rtp = dyn_cast<RankedTensorType>(tensor.
getType());
218 assert(
static_cast<Dimension>(exprs.size()) == lvlRank &&
219 "AffineMap does not have dimension-rank many results");
221 for (
Level l = 0; l < lvlRank; l++) {
222 if (!isa<AffineDimExpr>(exprs[l]) && !stt.
isDenseLvl(l))
239 OpOperand *out = op.getDpsInitOperand(0);
256 bool annotated =
false;
259 const auto map = env.
op().getMatchingIndexingMap(&t);
264 const Level lvlRank = map.getNumResults();
265 assert(!enc || lvlRank == enc.getLvlRank());
266 assert(
static_cast<Level>(env.
op().getRank(&t)) == lvlRank);
276 for (
Level l = 0; l < lvlRank; l++) {
279 if (idxReducBased && needIdxReduc) {
297 linalg::GenericOp op = env.
op();
302 llvm::cast<linalg::LinalgOp>(op.getOperation())
303 .createLoopRanges(builder, loc);
319 OpOperand *lhs = op.getDpsInitOperand(0);
320 assert(lhs->
get() == tensor);
329 bool isInit = op.isInitTensor(lhs);
340 assert(l < loopRange.size());
350 const auto map = env.
op().getMatchingIndexingMap(t);
352 const Level lvlRank = stt.getLvlRank();
353 assert(
static_cast<Level>(map.getNumResults()) == lvlRank);
354 const AffineExpr a = map.getResult(lvlRank - 1);
365 const auto map = env.
op().getMatchingIndexingMap(t);
367 if (stt.hasEncoding()) {
374 const Level lvlRank = stt.getLvlRank();
375 assert(
static_cast<Level>(map.getNumResults()) == lvlRank);
376 for (
Level l = 0; l < lvlRank; l++) {
377 const auto lvlExpr = map.getResult(l);
379 args.push_back(lvlCrd);
388 linalg::GenericOp op = env.
op();
403 linalg::GenericOp op = env.
op();
413 Value isFilled = builder.
create<memref::LoadOp>(loc, filled, index);
414 Value valAtIndex = builder.
create<memref::LoadOp>(loc, values, index);
415 return builder.
create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
421 linalg::GenericOp op = env.
op();
425 const LoopOrd numLoops = op.getRank(t);
441 scf::IfOp ifValidLexInsert = builder.
create<scf::IfOp>(
446 Value res = builder.
create<InsertOp>(loc, rhs, chain, ivs);
447 builder.
create<scf::YieldOp>(loc, res);
450 builder.
create<scf::YieldOp>(loc, chain);
471 Value isFilled = builder.
create<memref::LoadOp>(loc, filled, index);
472 Value cond = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
478 builder.
create<memref::StoreOp>(loc, tval, filled, index);
479 builder.
create<memref::StoreOp>(loc, index, added, count);
481 Value add = builder.
create<arith::AddIOp>(loc, count, one);
482 builder.
create<scf::YieldOp>(loc, add);
485 builder.
create<scf::YieldOp>(loc, count);
489 builder.
create<memref::StoreOp>(loc, rhs, values, index);
500 linalg::GenericOp op = env.
op();
510 return builder.
create<memref::LoadOp>(op.
getLoc(), ptr, args);
531 linalg::GenericOp op = env.
op();
537 builder.
create<memref::StoreOp>(loc, rhs, ptr, args);
560 builder.
create<scf::YieldOp>(loc, chain);
578 if (
auto arg = dyn_cast<BlockArgument>(e)) {
582 linalg::GenericOp op = env.
op();
583 if (arg.getOwner()->getParentOp() == op) {
589 return rewriter.
create<memref::LoadOp>(op.
getLoc(), ptr, args);
593 if (
auto indexOp = dyn_cast<linalg::IndexOp>(def))
596 if (def->getBlock() == block) {
598 for (
unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
601 i,
relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
615 linalg::GenericOp op = env.
op();
618 const auto kind = exp.
kind;
671 LoopId ldx,
bool atStart) {
677 linalg::GenericOp op = env.
op();
679 const auto map = op.getMatchingIndexingMap(&t);
681 const Level lvlRank = stt.getLvlRank();
682 assert(
static_cast<Level>(map.getNumResults()) == lvlRank);
683 for (
Level l = 0; l < lvlRank; l++) {
691 OpOperand *lhs = op.getDpsInitOperand(0);
731 linalg::GenericOp op = env.
op();
732 OpOperand *lhs = op.getDpsInitOperand(0);
744 auto dynShape = {ShapedType::kDynamic};
745 Type etp = cast<ShapedType>(tensor.
getType()).getElementType();
750 auto r = builder.
create<ExpandOp>(loc,
TypeRange({t1, t2, t3, t4}), tensor);
751 assert(r.getNumResults() == 4);
752 env.
startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
756 for (
LoopOrd i = 0; i < at; i++)
763 Value compress = builder.
create<CompressOp>(loc, values, filled, added,
764 count, chain, indices);
785 return isOuter && !isSparse;
793 llvm_unreachable(
"unexpected parallelization strategy");
800 linalg::GenericOp op = env.
op();
801 auto iteratorTypes = op.getIteratorTypesArray();
802 bool isSparse = llvm::any_of(tidLvls, [ldx, &env](
TensorLevel tidLvl) {
818 bool tryParallel,
bool needsUniv) {
822 return env.
emitter().enterCoIterationOverTensorsAtLvls(
823 builder, env.
op().getLoc(), tidLvls, reduc, tryParallel,
835 return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
844 while (
auto ifOp = dyn_cast_or_null<scf::IfOp>(
869 assert(y == yields.size());
870 builder.
create<scf::YieldOp>(loc, yields);
893 assert(lvl.has_value() &&
isUndefLT(lt));
895 lt = stt.getLvlType(*lvl);
901 assert(lvl.has_value());
904 clause = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
910 cond = cond ? builder.
create<arith::AndIOp>(loc, cond, clause) : clause;
921 scf::IfOp ifOp = builder.
create<scf::IfOp>(loc, types, cond,
true);
936 operands.push_back(
constantI1(builder, env.
op().getLoc(),
true));
948 if (!operands.empty())
949 builder.
create<scf::YieldOp>(env.
op().getLoc(), operands);
968 bool needsUniv =
false;
972 std::optional<Level> lvl,
1009 linalg::GenericOp op = env.
op();
1010 assert(tid < op.getNumDpsInputs());
1011 OpOperand *input = op.getDpsInputOperands()[tid];
1012 const auto lvlExprs = op.getMatchingIndexingMap(input).
getResults();
1017 const Level lvlRank = enc.getLvlRank();
1018 assert(lvlExprs.size() ==
static_cast<size_t>(lvlRank));
1019 for (
Level l = startLvl; l < lvlRank; l++) {
1021 if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
1036 for (
TensorId tid = 0, e = env.
op().getNumDpsInputs(); tid < e; tid++)
1045 const BitVector &simple = env.
lat(li).
simple;
1047 const std::optional<Level> outLvl = env.
merger().
getLvl(outTid, ldx);
1049 unsigned numloopCond = 0;
1050 bool hasNonUnique =
false;
1052 std::optional<Level> lvl,
1072 lvl = env.emitter().getCurrentDepth();
1078 hasNonUnique = !
isUniqueLT(lt) || hasNonUnique;
1081 }
else if (
isDenseLT(lt) || isIdxReduc) {
1085 linalg::GenericOp op = env.
op();
1086 if (tid >= op.getNumDpsInputs())
1092 if (!stt.hasEncoding())
1096 op.getMatchingIndexingMap(operand).
getResults();
1097 const Level lvlRank = stt.getLvlRank();
1098 assert(affines.size() ==
static_cast<size_t>(lvlRank));
1099 for (
Level l = 0; l < lvlRank; l++) {
1103 if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
1107 if (!isa<AffineConstantExpr>(exp)) {
1108 bool isAtLoop =
false;
1134 if (numloopCond == 0) {
1144 return numloopCond == 1 && !hasNonUnique;
1163 for (
auto [tidLvl, exp] : affineTidLvls) {
1171 llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
1178 return std::make_pair(loop, isSingleCond);
1184 bool isSingleCond) {
1191 }
else if (
auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1199 env.
emitter().exitCurrentLoop(rewriter, env.
op().getLoc(), reduc);
1200 return std::nullopt;
1208 unsigned idx,
unsigned ldx) {
1235 bool needsUniv =
startLoopSeq(env, rewriter, exp, at, ldx, lts);
1242 const unsigned lsize = env.
set(lts).size();
1243 for (
unsigned i = 0; i < lsize; i++) {
1246 auto [loop, isSingleCond] =
startLoop(env, rewriter, at, li, needsUniv);
1257 for (
unsigned j = 0;
j < lsize;
j++) {
1262 if (!isSingleCond) {
1263 scf::IfOp ifOp =
genIf(env, rewriter, at, lj);
1264 genStmt(env, rewriter, ej, at + 1);
1265 endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1267 genStmt(env, rewriter, ej, at + 1);
1273 needsUniv =
endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
1282 linalg::GenericOp op = env.
op();
1283 OpOperand *lhs = op.getDpsInitOperand(0);
1290 bool hasInserts =
false;
1319 if (op.getNumDpsInits() != 1 || !op.hasTensorSemantics())
1328 op,
"Loops not yet scheduled, try run --sparse-reinterpret-map "
1329 "before sparsification.");
1336 const unsigned numLoops = op.getNumLoops();
1344 Level maxLvlRank = 0;
1346 if (
auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
1362 if (op.getNumReductionLoops() > 0) {
1364 assert(isa<linalg::YieldOp>(yield));
1366 if (!isa<arith::AddFOp>(redop) && !isa<complex::AddOp>(redop) &&
1367 !isa<arith::AddIOp>(redop) && !isa<arith::SubFOp>(redop) &&
1368 !isa<complex::SubOp>(redop) && !isa<arith::SubIOp>(redop) &&
1369 !isa<arith::OrIOp>(redop) && !isa<arith::XOrIOp>(redop) &&
1370 !isa<ReduceOp>(redop)) {
1377 if (
failed(env.initTensorExp()))
1387 genStmt(env, rewriter, env.getExprId(), 0);
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, Value tensor)
Get the total number of compound affine expressions in the getMatchingIndexingMap for the given tenso...
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx, LatPointId p)
Generates a single if-statement within a while-loop.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, OpOperand *t)
Generates insertion code to implement dynamic tensor load for reduction.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, Value e, LoopId ldx)
Semi-ring branches are simply inlined by the sparsifier.
static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx, bool &isAtLoop)
Determines if affine expression is invariant.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl, AffineExpr a, LevelType lt, bool isSubExp=false, int64_t coefficient=1)
Helper method to inspect affine expressions for index variable reduction based codegen.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl< Value > &args)
Generates subscript for load/store on a dense or sparse tensor.
static void genConstantDenseAddressFromLevel(CodegenEnv &env, OpBuilder &builder, TensorId tid, Level startLvl)
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned idx, unsigned ldx)
Ends a loop sequence at given level.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, LoopId idx, LatPointId li, bool needsUniv, bool isSingleCond)
Ends a single loop in current sequence. Returns new values for needsUniv.
static bool translateBitsToTidLvlPairs(CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl< TensorLevel > &tidLvls, SmallVectorImpl< std::pair< TensorLevel, AffineExpr >> &affineTidLvls)
Returns true if the lattice's bits can be iterated by a for-loop.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse)
Returns parallelization strategy.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, LevelType lt, bool setLvlFormat=true)
Helper method to inspect affine expressions.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased)
Helper method to inspect sparse encodings in the tensor types.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, OpOperand *t)
Generates insertion code to implement dynamic tensor load.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op)
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, Value redInput, Value cntInput, Value insInput, Value validIns)
Generates end of true branch of if-statement within a while-loop.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopOrd idx, LoopId ldx, LatSetId lts)
Starts a loop sequence at given level.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, Value rhs)
Generates insertion code to implement dynamic tensor store.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e, LoopId ldx)
Recursively generates tensor expression.
static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter, ArrayRef< TensorLevel > tidLvls)
Whether or not the current loop being generated should be parallelized (if possible) according to the c...
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopOrd at, bool atStart)
Generates an expanded access pattern in innermost dimension.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, Value rhs)
Generates a store on a dense or sparse tensor.
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx, bool needsUniv)
Generates the induction structure for a while-loop.
static Operation * genCoIteration(CodegenEnv &env, OpBuilder &builder, LoopId idx, ArrayRef< TensorLevel > tidLvls, bool tryParallel, bool needsUniv)
Emit a loop to coiterate over the list of tensor levels.
static Operation * genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at, bool needsUniv, ArrayRef< TensorLevel > tidLvls)
Generates a for-loop or a while-loop, depending on whether it implements singleton iteration or co-it...
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, LoopId ldx, bool atStart)
Hoists loop invariant tensor loads for which indices have been exhausted.
static void genResult(CodegenEnv &env, RewriterBase &rewriter)
Converts the result computed by the sparse kernel into the required form.
static std::pair< Operation *, bool > startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at, LatPointId li, bool needsUniv)
Starts a single loop in current sequence.
static void genInitConstantDenseAddress(CodegenEnv &env, RewriterBase &rewriter)
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp)
Generates a load on a dense or sparse tensor.
static Value genInvariantValue(CodegenEnv &env, ExprId exp)
Generates an invariant value.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, LoopOrd at)
Recursively generates code while computing iteration lattices in order to manage the complexity of im...
static Value genIndex(CodegenEnv &env, OpOperand *t)
Generates index for load/store on sparse tensor.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpOperand & getOpOperand(unsigned idx)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
The code generation environment class aggregates a number of data structures that are needed during t...
void startReduc(ExprId exp, Value val)
const SparsificationOptions & options() const
Value getInsertionChain() const
std::optional< Operation * > genLoopBoundary(function_ref< std::optional< Operation * >(MutableArrayRef< Value > parameters)> callback)
Generates loop boundary statements (entering/exiting loops).
unsigned getLoopDepth() const
ArrayRef< LatPointId > set(LatSetId s) const
bool isCustomReduc() const
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tl) const
Value getExpandValues() const
TensorLevel makeTensorLevel(TensorId t, Level l) const
const LatPoint & lat(LatPointId l) const
constexpr TensorId makeTensorId(unsigned t) const
void startExpand(Value values, Value filled, Value added, Value count)
bool hasSparseOutput() const
void clearValidLexInsert()
bool atExpandLevel(OpOperand *o, unsigned rank, LoopOrd n) const
unsigned getLoopNum() const
void updateInsertionChain(Value chain)
Value getExpandCount() const
void startCustomReduc(ExprId exp)
linalg::GenericOp op() const
Value getLoopVar(LoopId i) const
Returns the induction-variable for the loop identified by the given LoopId.
Value getExpandFilled() const
auto unpackTensorLevelRange(ContainerTy &&c) const
Value getExpandAdded() const
void setValidLexInsert(Value val)
const TensorExp & exp(ExprId e) const
void updateExpandCount(Value count)
void updateReduc(Value val)
Value getValidLexInsert() const
bool isSparseOutput(OpOperand *o) const
constexpr LoopId makeLoopId(unsigned i) const
LevelType lt(TensorId t, LoopId i) const
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
const std::vector< Value > & getValBuffer() const
void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls)
Enters a new loop sequence; the loops within the same sequence start from the break points of previo...
Value genAffine(OpBuilder &builder, Location loc, AffineExpr a)
Generates code to compute an affine expression whose variables are LoopIds (i.e., a....
const std::vector< std::vector< Value > > & getPosits() const
Getters.
Value getLoopIV(LoopOrd n) const
Gets loop induction variable for the given LoopOrd.
auto getLoopIVsRange() const
Get the range of values for all induction variables.
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
void genDenseAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr)
Emits the address for a dense level based on the value evaluated by the provided affine expression.
void exitCurrentLoopSeq(OpBuilder &builder, Location loc)
Exits the current loop sequence; this will reset the universal index to 0.
LoopOrd getCurrentDepth() const
Gets the current depth of the loop-stack.
const std::vector< std::vector< Value > > & getCoords() const
A class to handle all iteration lattice operations.
std::optional< Level > getLvl(TensorId t, LoopId i) const
Gets the level number of the tth tensor on the ith loop.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
void setLevelAndType(TensorId t, LoopId i, Level lvl, LevelType lt)
Sets the level number and level-type of the tth tensor on ith loop.
void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const
Iterates over a set of TensorLoopIds, invoking the callback for each TensorLoopId and passing it the ...
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl, LevelType lt, unsigned coefficient)
Establishes the two-way map that i <-> <t, lvl, lt>.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
void clearExprValue(ExprId e)
Clears the value associated with the expression.
constexpr TensorId getSynTensorID() const
Gets the synthetic tensor's identifier (used for all invariant tensor expressions).
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
constexpr TensorId getOutTensorID() const
Gets the output tensor's identifier.
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
void setExprValue(ExprId e, Value v)
Sets the expression to have the associated value.
bool hasDependentLvl(LoopId i, TensorId t)
Whether the loop has dependent slice.
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllDense() const
Returns true for tensors where every level is dense.
Level getLvlRank() const
Returns the level-rank.
bool isDenseLvl(Level l) const
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
unsigned LatSetId
LatSet identifiers.
constexpr bool isLooseCompressedLT(LevelType lt)
Check if the LevelType is loose compressed (regardless of properties).
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
constexpr bool isUniqueLT(LevelType lt)
Check if the LevelType is unique (regardless of storage format).
uint64_t Level
The type of level identifiers and level-ranks.
constexpr bool isUndefLT(LevelType lt)
Check if the LevelType is the special undefined value.
unsigned LoopId
Loop identifiers.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
constexpr bool is2OutOf4LT(LevelType lt)
Check if the LevelType is 2OutOf4 (regardless of properties).
constexpr bool isDenseLT(LevelType lt)
Check if the LevelType is dense (regardless of properties).
constexpr bool isSingletonLT(LevelType lt)
Check if the LevelType is singleton (regardless of properties).
LevelType
This enum defines all the sparse representations supportable by the SparseTensor dialect.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
unsigned LoopOrd
The position of a loop in the loop-stack, or the position of a LoopId in a topologically-sorted list ...
constexpr bool isCompressedLT(LevelType lt)
Check if the LevelType is compressed (regardless of properties).
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
unsigned ExprId
TensorExp identifiers.
unsigned LatPointId
LatPoint identifiers.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateSparsificationPatterns(RewritePatternSet &patterns, const SparsificationOptions &options=SparsificationOptions())
Sets up sparsification rewriting rules with the given options.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options for the Sparsification pass.
SparseParallelizationStrategy parallelizationStrategy
ExprId exp
Identifier of the tensor expression.
BitVector simple
Simplified conjunction of TensorLoopId as bitvector.
Tensor expression. Represents an MLIR expression in tensor index notation.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Children children
All other expressions hold the ExprIds of their children.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.