9 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORLOOPEMITTER_H_
10 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORLOOPEMITTER_H_
20 namespace sparse_tensor {
115 bool hasOutput =
false,
bool isSparseOut =
false,
119 bool hasOutput =
false,
bool isSparseOut =
false,
120 unsigned numLoops = 0,
183 bool genDedup =
false,
bool needsUniv =
false);
192 return llvm::map_range(loopStack, [](
const LoopInfo &li) {
return li.iv; });
245 return std::make_pair(tidLvl % nt, tidLvl / nt);
249 template <
class ContainerTy>
251 using EltTy = decltype(*c.begin());
252 static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>,
TensorLevel>,
253 "Must be unpacking a TensorLevel range");
254 return llvm::map_range(std::forward<ContainerTy>(c), [
this](EltTy tl) {
259 template <
class ContainerTy>
261 using EltTy = decltype(*c.begin());
262 static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLvlCond>,
263 "Must be unpacking a TensorLvlCond range");
265 llvm::make_first_range(std::forward<ContainerTy>(c)));
271 const std::vector<std::vector<Value>> &
getPosits()
const {
return posits; };
272 const std::vector<std::vector<Value>> &
getCoords()
const {
return coords; };
273 const std::vector<std::vector<Value>> &
getHighs()
const {
return highs; };
275 return positionsBuffers;
278 return coordinatesBuffers;
283 return llvm::StringLiteral(
"Emitted from");
292 struct SliceLoopInfo final {
294 : tid(tid), lvl(lvl), reduced(reduced) {}
301 struct LoopInfo final {
302 LoopInfo(ArrayRef<TensorLevel> trivialTidLvls,
303 ArrayRef<SliceLoopInfo> sliceDrivenInfo, Operation *loop,
304 Block *userBlock, Value iv, StringAttr loopTag)
305 : trivialTidLvls(trivialTidLvls), sliceDrivenInfo(sliceDrivenInfo),
306 loop(loop), userCodeBlock(userBlock), iv(iv) {
319 const Operation *loop;
320 Block *
const userCodeBlock;
327 struct SliceInfo final {
330 SliceInfo(Value minCrd, Value offset, Value isNonEmpty,
331 std::optional<Level> slicedOnLvl,
unsigned depth)
332 : minCrd(minCrd), offset(offset), isNonEmpty(isNonEmpty),
333 slicedOnLvl(slicedOnLvl), depth(depth) {
335 assert(!slicedOnLvl || minCrd);
339 bool isInitialTensor()
const {
return !slicedOnLvl.has_value(); }
344 std::optional<Level> slicedOnLvl;
353 static constexpr uint8_t kSparseCond = 1 << 3;
356 static constexpr uint8_t kSliceCond = 1 << 2;
359 static constexpr uint8_t kAffineIdxCond = 1 << 1;
362 static constexpr uint8_t kAffineIdxCondUnRed = 1 << 0;
364 enum class LoopCondKind : uint8_t {
367 DenseSliceCond = kSliceCond,
368 DenseAffineCond = kAffineIdxCond,
369 DenseAffineUnRedCond = kAffineIdxCond | kAffineIdxCondUnRed,
371 SparseCond = kSparseCond,
372 SparseSliceCond = kSparseCond | kSliceCond,
373 SparseAffineCond = kSparseCond | kAffineIdxCond,
374 SparseAffineUnRedCond = kSparseCond | kAffineIdxCond | kAffineIdxCondUnRed,
376 using TensorLvlCond = std::pair<TensorLevel, LoopCondKind>;
379 static bool isSparseCond(LoopCondKind k) {
380 return static_cast<uint8_t
>(k) & kSparseCond;
382 static bool isDenseCond(LoopCondKind k) {
return !isSparseCond(k); }
385 static bool isSliceCond(LoopCondKind k) {
386 return static_cast<uint8_t
>(k) & kSliceCond;
390 static bool isAffineIdxCond(LoopCondKind k) {
391 return static_cast<uint8_t
>(k) & kAffineIdxCond;
393 static bool isTrivalIdxCond(LoopCondKind k) {
return !isAffineIdxCond(k); }
396 static bool isAffineIdxUnRedCond(LoopCondKind k) {
397 return isAffineIdxCond(k) &&
static_cast<uint8_t
>(k) & kAffineIdxCondUnRed;
399 static bool isAffineIdxRedCond(LoopCondKind k) {
400 return isAffineIdxCond(k) && !isAffineIdxUnRedCond(k);
406 static bool isCondWithExtraCheck(LoopCondKind k) {
407 return isSparseCond(k) && (isSliceCond(k) || isAffineIdxUnRedCond(k));
410 static LoopCondKind makeLoopCondKind(
bool isSparse,
bool isSlice,
411 bool isAffine,
bool isUnRedu) {
412 assert(!isUnRedu || isAffine);
414 bits = isSparse ? bits | kSparseCond : bits;
415 bits = isSlice ? bits | kSliceCond : bits;
416 bits = isAffine ? bits | kAffineIdxCond : bits;
417 bits = isUnRedu ? bits | kAffineIdxCondUnRed : bits;
418 LoopCondKind kind =
static_cast<LoopCondKind
>(bits);
421 assert(isSparse == isSparseCond(kind));
422 assert(isSlice == isSliceCond(kind));
423 assert(isAffine == isAffineIdxCond(kind));
424 assert(isUnRedu == isAffineIdxUnRedCond(kind));
428 void categorizeLoopCondition(ArrayRef<TensorLevel> tidLvls,
429 SmallVectorImpl<TensorLvlCond> &dnConds,
430 SmallVectorImpl<TensorLvlCond> &spConds);
437 MutableArrayRef<Value>)>;
440 bool shouldIteratedByForLoop(ArrayRef<TensorLvlCond> spConds,
bool genDedup);
453 Value genSegmentHigh(OpBuilder &builder, Location loc,
TensorId tid,
454 Level lvl, Value pos, Value pHi);
460 Value genSparseCrd(OpBuilder &builder, Location loc,
TensorId tid,
466 std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
467 Location loc, Value crd,
472 bool isOutputTensor(
TensorId tid)
const {
476 bool isSparseOutput(
TensorId tid)
const {
477 return isOutputTensor(tid) && isSparseOut;
481 return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
486 void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
491 void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
497 void enterTensorsAtDenseLvls(OpBuilder &builder, Location loc,
498 ArrayRef<TensorLvlCond> dnConds, Value iv,
499 SmallVectorImpl<SliceLoopInfo> &sliceInfo);
507 std::pair<Operation *, Value>
508 emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid,
509 Level lvl, Value lo, Value hi,
510 MutableArrayRef<Value> reduc,
bool isParallel);
518 std::pair<Operation *, Value>
519 emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
520 ArrayRef<TensorLvlCond> spConds,
521 MutableArrayRef<Value> reduc,
bool needsUniv);
524 Value genWhileLoopConditions(OpBuilder &builder, Location loc, ValueRange ivs,
528 std::optional<Value> genWhileLoopBody(OpBuilder &builder, Location loc,
529 ValueRange ivs, TensorLvlCond cond);
537 ValueRange genCheckedValue(OpBuilder &builder, Location loc, Value pred,
538 ValueRange curArg, TensorLvlCond cond);
564 void exitForLoop(RewriterBase &rewriter, Location loc,
565 MutableArrayRef<Value> reduc);
568 void exitWhileLoop(OpBuilder &builder, Location loc,
569 MutableArrayRef<Value> reduc);
578 const SliceInfo &getMostRecentSliceOnLvl(
TensorId tid,
Level lvl);
582 const SliceInfo &getFinalSliceOnLvl(
TensorId tid,
Level lvl) {
583 const SliceInfo &info = getMostRecentSliceOnLvl(tid, lvl);
584 assert(info.depth == dependentLvlMap[tid][lvl].size() - 1);
596 return remDepOnLevel(tid, lvl) == 1;
604 return remDepOnLevel(tid, lvl) == 0;
611 std::pair<Operation *, ValueRange>
612 genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo,
613 Value pHi, Value offset, Value size,
TensorId tid,
614 Level lvl, ValueRange userReduc,
615 LoopBodyBuilder bodyBuilder);
619 ValueRange genUnResolvedSliceTreeTraverse(
620 OpBuilder &builder, Location loc,
TensorId tid,
621 ArrayRef<const SliceInfo *> unResLvls,
622 std::optional<std::pair<TensorId, Level>> firstResLvl,
623 ValueRange userReduc, LoopBodyBuilder bodyBuilder);
630 void genResolvedSliceBegin(OpBuilder &builder, Location loc,
TensorId tid,
638 void genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
TensorId tid,
644 void invalidateSliceIterIdx(OpBuilder &builder, Location loc,
TensorId tid,
648 bool genSliceBegin(OpBuilder &builder, Location loc,
TensorId tid,
Level lvl);
653 std::tuple<Value, Value, Value> genSliceNextInduction(OpBuilder &builder,
668 Operation *localInsertPos;
677 std::vector<Value> tensors;
679 std::vector<std::vector<LevelType>> lvlTypes;
693 std::vector<std::vector<Value>> posits;
696 std::vector<std::vector<Value>> coords;
698 std::vector<std::vector<Value>> segHi;
699 std::vector<std::vector<Value>> highs;
700 std::vector<std::vector<Value>> lvlSizes;
701 std::vector<std::vector<Value>> positionsBuffers;
702 std::vector<std::vector<Value>> coordinatesBuffers;
703 std::vector<Value> valBuffer;
710 std::vector<bool> isSparseSlices;
712 std::vector<std::vector<Value>> sliceOffsets;
713 std::vector<std::vector<Value>> sliceStrides;
717 std::vector<std::vector<std::vector<std::pair<TensorLevel, unsigned>>>>
724 std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
728 std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;
731 std::vector<std::vector<unsigned>> levelReducedDep;
734 std::vector<std::vector<SliceInfo>> sliceStack;
746 std::vector<LoopInfo> loopStack;
751 std::vector<std::pair<Value, std::vector<std::tuple<TensorId, Level, bool>>>>
Base type for affine expression.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void exitCurrentLoop(RewriterBase &rewriter, Location loc, MutableArrayRef< Value > reduc={})
Generates code to exit the current loop (e.g., generates yields, forwards loop induction variables,...
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
const std::vector< Value > & getValBuffer() const
void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls)
Enters a new loop sequence; the loops within the same sequence start from the break points of the previous loop sequence.
const std::vector< std::vector< Value > > & getCoordinateBuffers() const
void initialize(ValueRange tensors, StringAttr loopTag=nullptr, bool hasOutput=false, bool isSparseOut=false, unsigned numLoops=0, DependentLvlGetter getter=nullptr)
Takes an array of input tensors, which the generated loops will iterate over.
Value genAffine(OpBuilder &builder, Location loc, AffineExpr a)
Generates code to compute an affine expression whose variables are LoopIds (i.e., a....
TensorId getOutTensorId() const
Gets the TensorId for output tensor.
TensorLevel makeTensorLevel(TensorId t, Level l) const
Compresses a TensorId and Level into a TensorLevel.
unsigned getNumManifestTensors() const
Gets the total number of manifest tensors (excluding the synthetic tensor).
const std::vector< std::vector< Value > > & getPosits() const
Getters.
Operation * enterCoIterationOverTensorsAtLvls(OpBuilder &builder, Location loc, ArrayRef< TensorLevel > tidLvls, MutableArrayRef< Value > reduc={}, bool isParallel=false, bool genDedup=false, bool needsUniv=false)
Emits a co-iteration loop over a set of tensors.
const std::vector< std::vector< Value > > & getPositionBuffers() const
Operation * enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid, Level lvl, AffineExpr affine, MutableArrayRef< Value > reduc={})
Enters a loop that tries to locate a coordinate in a sparse level based on the value evaluated by the provided affine expression.
std::pair< TensorId, Level > unpackTensorLevel(TensorLevel tidLvl) const
De-compresses a TensorLevel back to a pair of TensorId and Level.
auto unpackTensorLevelRange(ContainerTy &&c) const
Converts a range of TensorLevel to a range of std::pair<TensorId, Level>
unsigned getNumTensors() const
Gets the total number of tensors that loopEmitter is operating on.
SmallVector< Value > getLoopIVs() const
Fills the out-parameter with the loop induction variables for all loops in the current loop-stack.
Value getLoopIV(LoopOrd n) const
Gets loop induction variable for the given LoopOrd.
auto getLoopIVsRange() const
Get the range of values for all induction variables.
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
void genDenseAffineAddress(OpBuilder &builder, Location loc, TensorLevel tidLvl, AffineExpr lvlExpr)
Emits the address for a dense level based on the value evaluated by the provided affine expression.
auto unpackTensorLevelFromCondRange(ContainerTy &&c) const
void exitCurrentLoopSeq(OpBuilder &builder, Location loc)
Exits the current loop sequence, this will reset universal index to 0.
LoopOrd getCurrentDepth() const
Gets the current depth of the loop-stack.
TensorId getSynTensorId() const
Gets the TensorId for synthetic tensor.
const std::vector< std::vector< Value > > & getCoords() const
const std::vector< std::vector< Value > > & getHighs() const
uint64_t Level
The type of level identifiers and level-ranks.
unsigned LoopOrd
The position of a loop in the loop-stack, or the position of a LoopId in a topologically-sorted list of LoopIds.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.