#define CMPI(p, l, r)                                                          \
  (arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::p, (l), (r))      \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (scf::YieldOp::create(builder, loc, (vs)))
#define ADDI(lhs, rhs) (arith::AddIOp::create(builder, loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (arith::AndIOp::create(builder, loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (arith::SubIOp::create(builder, loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (arith::MulIOp::create(builder, loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (arith::RemUIOp::create(builder, loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (arith::DivUIOp::create(builder, loc, (lhs), (rhs)))
#define SELECT(c, l, r) (arith::SelectOp::create(builder, loc, (c), (l), (r)))
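
// Illustrative sketch (not part of the original source): with an OpBuilder
// `builder` and a Location `loc` in scope, the shorthand macros above compose
// directly, e.g.
//   Value isEven = CMPI(eq, REMUI(idx, C_IDX(2)), C_IDX(0));
//   Value next   = SELECT(isEven, ADDI(idx, C_IDX(1)), idx);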

static LLVM_ATTRIBUTE_UNUSED void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  memref = memref::CastOp::create(/*...*/);
  // ...
}

static bool isIntOrFPZero(Attribute attr) {
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return false;
}

static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  // ...
  return cast<Value>(ofr);
}

static Value tryFoldTensors(Value t) {
  // ...
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    // ...
    if (matchPattern(/*...*/,
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst))
      return padOp.getSource();
  }
  // ...
}
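
// LoopEmitter constructor and initialize(): take the array of input tensors
// that the generated loops will iterate over and set up the per-tensor,
// per-level bookkeeping used by the emitter.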
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter /*...*/) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;

  this->tensors.assign(ts.begin(), ts.end());
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);
  // ...
  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);
  // ...
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
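
  // Note on the bookkeeping set up above: `valBuffer` holds the per-tensor
  // value buffers (memrefs), `lvls`/`iters` hold the per-level
  // SparseTensorLevel objects and SparseIterators, and `spIterVals` holds the
  // iterator SSA values when lowering to sparse_tensor.iterate/coiterate ops.
  // `dependentLvlMap`, `sliceMeta`, and `levelReducedDep` track levels whose
  // coordinates depend on multiple loops (non-trivial affine indexing).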

  for (TensorId tid = 0; tid < numTensors; tid++) {
    if (tid == synTensorId) {
      // ...
    } else {
      const Value t = tensors[tid];
      // ...
    }
    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);
    // ...
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        llvm::sort(deps, llvm::less_first());
        // ...
        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        // ...
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}

std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  // ...
  if (folded != tensor) {
    // ...
    if (padOp.getPaddedDims().test(l)) {
      // ...
    }
  }
  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    // ...
    it = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
  }
  // ...
}
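
// initializeLoopEmit(): starts a loop-emitting session by generating all the
// buffers needed to iterate over the tensors (dense buffers via
// bufferization::ToBufferOp, sparse value arrays via ToValuesOp), creating
// the per-level iterators, and setting up the synthetic tensor's dense levels
// from the loop bounds provided by `synSetter`.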

void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
                                     LoopEmitter::OutputUpdater updater,
                                     LoopEmitter::SynTensorBoundSetter synSetter) {
  // ...
      const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
      // ...
      const auto shape = rtp.getShape();
      // ...
      bool isOutput = isOutputTensor(t);
      Type elementType = stt.getElementType();
      if (!stt.hasEncoding()) {
        // ...
        if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
          /*...*/;
        Value denseVal =
            bufferization::ToBufferOp::create(builder, loc, denseTp, tensor);
        // ...
        if (isOutput && updater)
          denseVal = updater(builder, loc, denseVal, tensor);
        valBuffer[t] = denseVal;
      } else {
        // ...
        valBuffer[t] = ToValuesOp::create(builder, loc, tensor);
      }
  // ...
  for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
    Value sz = loopHighs[i] = synSetter(builder, loc, i);
    // ...
    lvls[synId][i] = std::move(stl);
    iters[synId][i].emplace_back(std::move(it));
  }
  // ...
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // ...
    const Level lvlRank = stt.getLvlRank();
    // ...
    for (Level l = 0; l < lvlRank; l++) {
      if (!dependentLvlMap[t][l].empty())
        /*...*/;
      // ...
      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }
  // ...
  initSubSectIterator(builder, loc);
}
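
// initSubSectIterator(): for levels whose coordinates depend on multiple loops
// (non-trivial affine index expressions), wraps the plain level iterators into
// non-empty-subsection / subsection-traversal iterators
// (makeNonEmptySubSectIterator, makeTraverseSubSectIterator), processed in
// dependency-reduction order.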

void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    // ...
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // ...
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    llvm::sort(depRedOrder, llvm::less_first());
    // ...
    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      // ...
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // ...
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          // ...
        }
        it = makeNonEmptySubSectIterator(/*...*/,
                                         std::move(lvlIt), size, curDep.second,
                                         /*...*/);
      } else {
        it = makeTraverseSubSectIterator(/*...*/,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
}
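
// categorizeIterators(): splits the iterators for a set of tensor levels into
// random-access iterators (raIters) and sparse iterators (spIters); the sparse
// iterators are then sorted by descending iterator kind.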

void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // ...
      raIters.push_back(it);
  // ...
      spIters.push_back(it);
  // ...
  llvm::stable_sort(spIters, [](auto lhs, auto rhs) {
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}
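
// enterNewLoopSeq() / exitCurrentLoopSeq(): a loop sequence groups the loops
// that co-iterate over the same set of tensor levels. Entering a sequence
// prepares the level iterators and pushes a universal index (starting at 0)
// onto loopSeqStack; exiting pops it again.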

void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
                                  ArrayRef<TensorLevel> tidLvls) {
  assert(loopSeqStack.size() == loopStack.size());
  // ...
    levelReducedDep[tid][lvl]++;
    prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
  // ...
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());
}

void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) {
  assert(loopSeqStack.size() == loopStack.size() + 1);
  // ...
    levelReducedDep[tid][lvl]--;
  // ...
  loopSeqStack.pop_back();
}
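
// genAffine(): emits code computing an affine expression over loop induction
// variables (AffineDimExpr positions map to LoopIds). Illustrative sketch
// (not from the original source): for the expression d0 * 2 + d1 it would
// emit roughly
//   Value tmp = MULI(loopStack[0].iv, C_IDX(2));
//   Value res = ADDI(tmp, loopStack[1].iv);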

Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {
  // ...
  auto [lo, hi] = iter.genForCond(builder, loc);
  // ...
  if (isParallel) {
    scf::ParallelOp parOp =
        scf::ParallelOp::create(builder, loc, lo, hi, step, reduc);
    // ...
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];
    // ...
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
  } else {
    scf::ForOp forOp = scf::ForOp::create(builder, loc, lo, hi, step, reduc);
    // ...
    iv = forOp.getInductionVar();
    // ...
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
  }
  // ...
  if (/*...*/) {
    crd = iter.deref(builder, loc);
  } else {
    iter.locate(builder, loc, iv);
  }
  // ...
}
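
// emitWhileLoopOverTensorsAtLvls(): co-iterates several sparse iterators with
// a single scf.while; the actual loop construction is delegated to
// genCoIteration() below, passing the current universal index (if any) from
// loopSeqStack.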

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

static bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
  if (spIters.size() > 1)
    return false;
  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();
  return true;
}
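
// enterCurrentCoIterationCase(): when lowering to sparse_tensor.coiterate,
// initializes the case region selected by `caseBit`: creates the region block
// with the expected block arguments (coordinate, per-space iterators, and
// loop-carried values) and records the iterator SSA values in spIterVals.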

Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
                                                 Location loc, I64BitSet caseBit,
                                                 unsigned caseIdx,
                                                 MutableArrayRef<Value> reduc) {
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  // ...
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(/*...*/ &&
         "re-initialize the same coiteration case region.");
  // ...
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // ...
  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
        /*...*/);
  }
  // ...
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  // ...
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  // ...
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  // ...
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
  // ...
      spIterVals[tl.first][tl.second] = nullptr;
  // ...
  assert(iters.empty());
  // ...
}
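
// enterCoIterationOverTensorsAtLvls(): emits a co-iteration loop over a set of
// tensor levels. When lowering to sparse_tensor.iterate/coiterate it builds
// those ops directly from extracted iteration spaces; otherwise it categorizes
// the iterators and emits either a single scf.for/scf.parallel or an
// scf.while-based co-iteration.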

Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
    unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
    bool needsUniv) {
  // ...
    if (tidLvls.size() == 1) {
      // ...
      Value t = tensors[tid];
      // ...
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? ExtractIterSpaceOp::create(builder, loc, t)
                   : ExtractIterSpaceOp::create(builder, loc, t,
                                                spIterVals[tid][lvl - 1], lvl);
      IterateOp iterOp = IterateOp::create(
          builder, loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();
      // ...
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
      // ...
      loopStack.emplace_back(/*...*/,
                             iterOp.getCrds().front(), loopTag);
      // ...
    }
    // ...
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? ExtractIterSpaceOp::create(builder, loc, t)
                   : ExtractIterSpaceOp::create(builder, loc, t,
                                                spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    // ...
    auto coIterOp = CoIterateOp::create(builder, loc, spaces, reduc, numCases);
    // ...
    loopStack.emplace_back(tidLvls, coIterOp, nullptr,
                           /*...*/);
    // ...

  tryParallel = tryParallel && reduc.size() <= 1;
  // ...
  categorizeIterators(tidLvls, raIters, spIters);
  // ...
  needsUniv = !spIters.empty() && needsUniv;
  // ...
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    // ...
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
  } else {
    for (auto *it : spIters) {
      // ...
    }
    for (auto *it : raIters)
      /*...*/;
    // ...
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }
  // ...
    it->locate(builder, loc, iv);
  // ...
}
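
// locateLvlAtAffineAddress(): emits the address for a dense level whose
// coordinate is given by an affine expression over loop induction variables;
// the level's (random-access) iterator is simply located at the evaluated
// coordinate.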

void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                           TensorLevel tidLvl,
                                           AffineExpr lvlExpr) {
  // ...
  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);
  // ...
  it.locate(builder, loc, lvlCrd);
}

void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                             TensorId tid, Level lvl) {
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();
  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);
  // ...
}
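
// exitForLoop(): terminates the current for-style loop: emits the proper
// yield (sparse_tensor.yield for IterateOp, scf.yield for scf.for, an
// scf.reduce block for scf.parallel reductions), moves the insertion point
// past the loop, and rewrites `reduc` in place with the loop results.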

void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                              MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  // ...
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    sparse_tensor::YieldOp::create(rewriter, loc, reduc);
    // ...
    llvm::copy(iterateOp.getResults(), reduc.begin());
    // ...
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      scf::YieldOp::create(rewriter, loc, reduc);
    }
    // ...
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // ...
      assert(redExp->getUses().empty());
      // ...
      Value redVal = parOp.getInitVals().front();
      // ...
      unsigned numUsers = 0;
      // ...
        if (op->getParentOp() == parOp)
          numUsers++;
      // ...
      assert(numUsers == 1);
      // ...
      auto redOp = scf::ReduceOp::create(rewriter, loc, curVal);
      // ...
      Block *redBlock = &redOp.getReductions().front().front();
      // ...
      scf::ReduceReturnOp::create(rewriter, loc, newRed->getResult(0));
      // ...
    }
    // ...
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
  // ...
}
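
// exitWhileLoop(): terminates the current co-iteration while-loop: forwards
// the iterators, appends the forwarded cursors and reduction values as yield
// operands, advances the universal index by one if it is in use, and stores
// the loop results back into `reduc` and loopSeqStack.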

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  // ...
      Value uniIdx = whileOp.getResults().back();
      it.locate(builder, loc, uniIdx);
  // ...
  for (auto &i : reduc) {
    operands.push_back(i);
    // ...
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }
  // ...
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    // ...
    operands.push_back(ADDI(iv, one));
    // ...
    loopSeqStack.back().first = whileOp->getResults().back();
  }
  // ...
  if (!operands.empty())
    /*...*/;
  // ...
}
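
// exitCurrentLoop(): generates code to exit the current loop (yields,
// induction-variable forwarding, reduction updates) by dispatching to
// exitWhileLoop() or exitForLoop(), then pops the loop from loopStack.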

void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
                                  MutableArrayRef<Value> reduc) {
  const LoopInfo &loopInfo = loopStack.back();
  // ...
    if (isa<IterateOp>(p))
      sparse_tensor::YieldOp::create(rewriter, loc, reduc);
    // ...
    loopStack.pop_back();
    // ...

  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // ...
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    // ...
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
}
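
// genCoIteration(): builds the scf.while op that co-iterates the given sparse
// iterators: the "before" region ANDs the per-iterator while-conditions
// (genWhileCond) into a single scf.condition, and the "after" region
// dereferences each iterator and produces the coordinate (`min`) returned to
// the caller together with the loop operation.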

std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
                                             ArrayRef<SparseIterator *> iters,
                                             MutableArrayRef<Value> reduc,
                                             Value uniIdx, bool userReducFirst) {
  // ...
    ivs.append(reduc.begin(), reduc.end());
  // ...
    ivs.append(itVals.begin(), itVals.end());
  // ...
    ivs.append(reduc.begin(), reduc.end());
  // ...
    ivs.push_back(uniIdx);
  // ...
  assert(!llvm::is_contained(ivs, nullptr));
  // ...
  auto whileOp = scf::WhileOp::create(builder, loc, types, ivs);
  // ...
  Value whileCond = nullptr;
  // ...
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
  // ...
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  scf::ConditionOp::create(builder, loc, whileCond, before->getArguments());
  // ...
    it->deref(builder, loc);
  // ...
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    /*...*/;
  // ...
    min = whileOp.getAfterArguments().back();
  // ...
  return {whileOp, min};
}