#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))      \
       .getResult())

#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
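// Illustrative usage (not part of the original file): the shorthands assume an
// `OpBuilder builder` and a `Location loc` in the enclosing scope, so a guarded
// increment can be written as:
//
//   Value next     = ADDI(iv, C_IDX(1));        // iv + 1
//   Value inBounds = CMPI(ult, next, hi);       // next < hi (unsigned)
//   Value clamped  = SELECT(inBounds, next, hi);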
// In dumpIndexMemRef: cast the index memref to an unranked memref before
// handing it to the runtime print helper.
  memref = builder.create<memref::CastOp>(/*...*/);

// In isIntOrFPZero: both float and integer zero attributes count as zero.
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;

// In unFoldOpIntResult: an OpFoldResult that already holds a Value is returned
// as is.
  return ofr.get<Value>();

// In tryFoldTensors: a tensor.pad whose source carries the same identity
// sparse encoding folds to its source when the padding value is a constant
// zero.
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
    Attribute padCst;
    if (matchPattern(/*...*/,
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
        isIntOrFPZero(padCst))
      return padOp.getSource();
  }
// The constructor simply forwards to initialize() (trailing parameters elided
// here).
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                         bool isSparseOut, unsigned numLoops,
                         DependentLvlGetter dimGetter /*, ...*/) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}

void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                             bool isSparseOut, unsigned numLoops,
                             DependentLvlGetter dimGetter,
                             SparseEmitStrategy emitStrategy) {
  // First initialize the top-level fields.
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  // One extra synthetic tensor is appended after the manifest tensors.
  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;

  this->tensors.assign(ts.begin(), ts.end());
  // Containers of length numTensors.
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  // Index-reduction (slice-driven loop) related fields.
  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
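  // Note (exposition, not in the original source): dependentLvlMap records,
  // per tensor and per level, the loops (and their coefficients) that a
  // non-trivial affine index expression on that level depends on, e.g. a level
  // addressed by d0 + 2 * d1 depends on two loops. sliceMeta appears to cache
  // per-dependency slice sizes/strides, and levelReducedDep counts how many of
  // those dependencies have been reduced so far while emitting slice-driven
  // loops.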
  // Initialize the per-tensor and per-level fields.
  for (TensorId tid = 0; tid < numTensors; tid++) {
    if (tid == synTensorId) {
      // ... (the synthetic tensor has one level per loop, i.e. lvlRank == numLoops)
    } else {
      const Value t = tensors[tid];
      // ... (scalars are skipped; otherwise lvlRank is t's level-rank)
    }

    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);

    // Slice-driven loop related initialization.
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        // Sort the dependent loops by loop id.
        llvm::sort(deps, llvm::less_first());

        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        // ...
        sliceMeta[tid][l].reserve(depends);
      }
    }
  }
}
// makeLevelIterator: creates the SparseIterator for level l of tensor t,
// wrapping it for padded or sliced tensors where needed.
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                               Level l) {
  Value tensor = tensors[t];
  // ...
  if (folded != tensor) {
    // The tensor folded to the source of a zero-padding tensor.pad; wrap the
    // iterator into a padded iterator for the padded dimensions.
    if (padOp.getPaddedDims().test(l)) {
      // ...
    }
  }
  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
    // Wrap the iterator for a sparse tensor slice.
    auto slicedIt = makeSlicedLevelIterator(
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
    // ...
  }
  // ...
}

// initializeLoopEmit, first pass over the manifest tensors: perform the
// required bufferization and set up the value buffers.
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // ...
    const auto shape = rtp.getShape();
    // ...
    bool isOutput = isOutputTensor(t);
    Type elementType = stt.getElementType();
    if (!stt.hasEncoding()) {
      // Non-annotated dense tensor: materialize a memref from the tensor.
      // ...
      if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        denseTp = bufferization::getMemRefTypeWithFullyDynamicLayout(rtp);
      Value denseVal =
          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
      // Dense outputs may need extra initialization, delegated to the client.
      if (isOutput && updater)
        denseVal = updater(builder, loc, denseVal, tensor);

      valBuffer[t] = denseVal;
    } else {
      // Annotated sparse tensor: the values buffer comes from the sparse
      // values array.
      valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
    }

  // Set the upper bound (and the synthetic tensor level) for each loop via the
  // client-provided callback.
  for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
    Value sz = loopHighs[i] = synSetter(builder, loc, i);
    // ... (create a synthetic dense level of size sz and its iterator)
    lvls[synId][i] = std::move(stl);
    iters[synId][i].emplace_back(std::move(it));
  }
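  // Note (exposition, not in the original): the synthetic tensor models the
  // loop index space itself; it has one conceptual dense level per loop, and
  // loopHighs[i] (set through the synSetter callback) acts both as that
  // level's size and as the upper bound of loop i.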
  // initializeLoopEmit, second pass over the manifest tensors: create the
  // sparse levels and their (trivial) level iterators.
    const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
    // ...
    const Level lvlRank = stt.getLvlRank();
    // Scan all levels of the current tensor.
    for (Level l = 0; l < lvlRank; l++) {
      // ... (create lvls[t][l] via makeSparseTensorLevel)
      if (!dependentLvlMap[t][l].empty())
        continue;

      auto it = makeLevelIterator(builder, loc, t, l);
      iters[t][l].emplace_back(std::move(it));
    }

  // Levels with non-trivial index dependencies get subsection iterators last.
  initSubSectIterator(builder, loc);
}
// initSubSectIterator: builds subsection iterators for the levels whose index
// expressions depend on more than one loop.
  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    // ...

    // Compute the dependency reduction order.
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      // Reverse the queue into a stack.
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
    }

    if (depRedOrder.empty())
      continue;

    std::sort(depRedOrder.begin(), depRedOrder.end(),
              [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });

    for (auto [loop, t, lvl] : depRedOrder) {
      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
      assert(curDep.first == loop);
      remDepStack[t][lvl].pop_back();

      auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
      // ...
      if (!parent && lvl > 0) {
        if (dependentLvlMap[t][lvl - 1].empty()) {
          parent = iters[t][lvl - 1].back().get();
        }
      }

      std::unique_ptr<SparseIterator> it;
      if (!remDepStack[t][lvl].empty()) {
        // Unresolved dependencies remain: build a non-empty subsection
        // iterator whose size accumulates the remaining strides.
        for (auto [loop, stride] : remDepStack[t][lvl]) {
          // ...
        }
        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
                                         std::move(lvlIt), size, curDep.second,
                                         emitStrategy);
      } else {
        // All dependencies resolved: traverse the subsection.
        it = makeTraverseSubSectIterator(/*...*/,
                                         std::move(lvlIt), loopHighs[loop],
                                         curDep.second, emitStrategy);
      }
      lastIter[t] = it.get();
      iters[t][lvl].emplace_back(std::move(it));
    }
  }
// categorizeIterators: splits the iterators of the given tensor levels into
// random-access iterators (raIters) and sparse iterators (spIters).
void LoopEmitter::categorizeIterators(
    ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
    SmallVectorImpl<SparseIterator *> &spIters) {
  // ...
    if (it->randomAccessible())
      raIters.push_back(it);
    else
      spIters.push_back(it);
  // ...
  std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) {
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  });
}
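  // Note (exposition, not in the original): sorting by descending IterKind
  // presumably ensures that the more specialized iterators (subsection, pad,
  // or slice wrappers) are handled before plain trivial level iterators when
  // the co-iteration conditions are generated.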
// enterNewLoopSeq: prepares every tensor level used by the new loop sequence.
  assert(loopSeqStack.size() == loopStack.size());
  // ...
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
  // ...
  // The universal index of a new loop sequence starts from 0.
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());

// exitCurrentLoopSeq: undoes the per-level bookkeeping and pops the sequence.
  assert(loopSeqStack.size() == loopStack.size() + 1);
  // ...
    levelReducedDep[tid][lvl]--;

  loopSeqStack.pop_back();
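  // Example (exposition, not in the original): for the affine expression
  // d0 * 3 + d1, the genAffine switch below emits
  //   ADDI(MULI(loopStack[0].iv, C_IDX(3)), loopStack[1].iv)
  // i.e. an arith.muli and an arith.addi over the loop induction variables.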
// genAffine: lowers an AffineExpr whose dimensions are LoopIds.
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
  }
  case AffineExprKind::Add: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return ADDI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Mul: {
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return MULI(genAffine(builder, loc, binOp.getLHS()),
                genAffine(builder, loc, binOp.getRHS()));
  }
  case AffineExprKind::Constant: {
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    return C_IDX(c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
// emitForLoopOverTensorAtLvl: emits a single scf.for (or scf.parallel) loop
// over the given iterator.
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
    OpBuilder &builder, Location loc, SparseIterator &iter,
    MutableArrayRef<Value> reduc, bool isParallel) {
  Value step = C_IDX(1);
  auto [lo, hi] = iter.genForCond(builder, loc);
  // ...
  if (isParallel) {
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    // ...
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];

    // In-place update of the reduction variables: the init values temporarily
    // stand in for the real reduction variables until exitForLoop rewrites
    // them into an scf.reduce.
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
  } else {
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    // ...
    iv = forOp.getInductionVar();

    // In-place update of the reduction variables.
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
  }
  // A sparse iterator dereferences to the coordinate at its current position;
  // a random-access iterator is simply located at the induction value.
  // ...
    crd = iter.deref(builder, loc);
  // ...
    iter.locate(builder, loc, iv);
  // ...
}
// emitWhileLoopOverTensorsAtLvls: co-iteration is delegated to genCoIteration,
// passing the current universal index along when one is needed.
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
    OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
    MutableArrayRef<Value> reduc, bool needsUniv) {
  return genCoIteration(builder, loc, spIters, reduc,
                        needsUniv ? loopSeqStack.back().first : nullptr);
}

// shouldIteratedByForLoop: a single sparse iterator (or none at all) may be
// handled by a plain scf.for; co-iterating several sparse iterators needs a
// while loop.
  if (spIters.size() > 1)
    return false;

  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();
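  // Note (exposition, not in the original): "iteratable by for" essentially
  // means the iterator walks a contiguous range, so iter.genForCond() can hand
  // scf.for a plain [lo, hi) bound as in emitForLoopOverTensorAtLvl above;
  // anything that must test and advance several cursors per step goes through
  // the scf.while loop built by genCoIteration.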
// enterCurrentCoIterationCase: populates the block of one case region of the
// enclosing sparse_tensor.coiterate op and rewires the SSA values for it.
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  // ...
  Region &caseRegion = coIterOp.getRegion(caseIdx);
  assert(caseRegion.getBlocks().empty() &&
         "re-initialize the same coiteration case region.");

  // The case block's arguments combine the user-provided iteration arguments,
  // the used coordinates (of index type), and one iterator per iteration
  // space that participates in this case.
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  // ...
  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
            .getIteratorType());
  }
  // ...

  // Update the loop coordinate and the reduction variables for the new region.
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  // ...
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  // ...
  // Hand out the sparse iterator values: tensor levels active in this case
  // consume one iterator each, the others get nullptr.
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
  // ...
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
    // ...
      spIterVals[tl.first][tl.second] = nullptr;
  // ...
  // Every iterator value must have been consumed.
  assert(iters.empty());
  // enterCoIterationOverTensorsAtLvls, sparse-iterator strategy: a single
  // tensor level lowers to extract_iteration_space + sparse_tensor.iterate.
  if (tidLvls.size() == 1) {
    // ...
    Value t = tensors[tid];

    // Extract the iteration space (seeded by the parent iterator for lvl > 0).
    ExtractIterSpaceOp extractSpaceOp =
        lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                 : builder.create<ExtractIterSpaceOp>(
                       loc, t, spIterVals[tid][lvl - 1], lvl);

    IterateOp iterOp = builder.create<IterateOp>(
        loc, extractSpaceOp.getExtractedSpace(), reduc);
    spIterVals[tid][lvl] = iterOp.getIterator();

    // In-place update of the reduction variables.
    llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
    // ... (move the insertion point into the loop body)
    loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
                           iterOp.getCrds().front(), loopTag);
    return iterOp;
  }
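  // Note (exposition, not in the original): under the sparse-iterator
  // strategy the emitter does not build scf loops directly; it emits the
  // higher-level sparse_tensor.extract_iteration_space, sparse_tensor.iterate
  // and sparse_tensor.coiterate ops, whose lowering happens separately.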
  // Sparse-iterator strategy with multiple tensor levels: co-iterate over all
  // the extracted iteration spaces.
  // ...
    Value t = tensors[tid];
    ExtractIterSpaceOp extractSpaceOp =
        lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                 : builder.create<ExtractIterSpaceOp>(
                       loc, t, spIterVals[tid][lvl - 1], lvl);
    spaces.push_back(extractSpaceOp.getExtractedSpace());
  // ...
  auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
  // The case regions (and their block arguments) are only filled in later by
  // enterCurrentCoIterationCase.
  loopStack.emplace_back(tidLvls, coIterOp, nullptr, /*...*/);

  // Classic strategy: emit scf loops directly. Parallelization is attempted
  // only when there is at most one reduction value.
  tryParallel = tryParallel && reduc.size() <= 1;
  // ...
  categorizeIterators(tidLvls, raIters, spIters);
  // The universal index is only needed when there is at least one sparse
  // iterator to co-iterate over.
  needsUniv = !spIters.empty() && needsUniv;
  // ...
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
    // ...
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
  } else {
    for (auto *it : spIters) {
      // ...
    }
    for (auto *it : raIters)
      // ...
    // ...
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
  }
  // Random-access (dense) iterators are then positioned at the emitted loop's
  // induction value.
    it->locate(builder, loc, iv);
// locateLvlAtAffineAddress: initializes the level iterator and positions it at
// the coordinate computed from the affine expression.
  const SparseIterator *parent =
      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);
  // ... (lvlCrd = genAffine(builder, loc, lvlExpr))
  it.locate(builder, loc, lvlCrd);

// prepareLoopOverTensorAtLvl: initializes the iterator before the loop over
// the given tensor level is emitted.
  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();
  const SparseIterator *parent =
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);
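  // Note (exposition, not in the original): SparseIterator follows a small
  // protocol used throughout this file: genInit() materializes the cursor
  // (possibly relative to a parent iterator on the previous level), locate()
  // positions a random-accessible iterator at a given coordinate, deref()
  // yields the current coordinate, and genForCond()/genWhileCond() produce the
  // bounds and conditions consumed by the emitted scf.for / scf.while loops.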
// exitForLoop: emits the terminating yield, moves the insertion point past the
// loop, and updates the reduction variables in place.
  const LoopInfo &loopInfo = loopStack.back();
  // Sparse-iterator strategy: the loop is a sparse_tensor.iterate op.
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    // ...
    llvm::copy(iterateOp.getResults(), reduc.begin());

  // Otherwise the loop is either an scf.for ...
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    }
    // ...
    llvm::copy(forOp.getResults(), reduc.begin());
  } else {
    // ... or an scf.parallel, whose single reduction is rewritten into an
    // scf.reduce region.
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      // The reduction expression should have no remaining uses.
      assert(redExp->getUses().empty());
      // ...
      Value redVal = parOp.getInitVals().front();
      // ...
      // The reduction expression must be the only user of the reduction value
      // inside the parallel loop.
      unsigned numUsers = 0;
      // ...
        if (op->getParentOp() == parOp)
          numUsers++;
      // ...
      assert(numUsers == 1);

      // Move the reduction expression into a fresh scf.reduce block, replacing
      // its operands with the block arguments.
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      Block *redBlock = &redOp.getReductions().front().front();
      // ...
      rewriter.modifyOpInPlace(
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
      // ...
    }
    // In-place update of the reduction variables.
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  }
// exitWhileLoop: forwards every sparse iterator past the current coordinate,
// yields the loop-carried values, and moves past the scf.while loop.
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;
  // ...
    // Random-accessible (dense) iterators are re-positioned at the universal
    // index carried by the while loop.
    Value uniIdx = whileOp.getResults().back();
    it.locate(builder, loc, uniIdx);
  // ...
  // User reduction values.
  for (auto &i : reduc) {
    operands.push_back(i);
    // In-place update of the user reduction variable.
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  }

  // The (optional) universal index: forward it by one and remember the break
  // point of the current loop sequence.
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    operands.push_back(ADDI(iv, one));
    loopSeqStack.back().first = whileOp->getResults().back();
  }

  if (!operands.empty())
    YIELD(operands);
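  // Note (exposition, not in the original): the universal index is the shared
  // coordinate counter used during co-iteration; it is yielded as iv + 1 so
  // that the next iteration resumes right after the coordinate just processed,
  // and the final value recorded on loopSeqStack becomes the starting point of
  // the next loop in the same sequence.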
// exitCurrentLoop: generates the yields, forwards the loop induction values of
// the innermost loop, and pops it off the loop stack.
  const LoopInfo &loopInfo = loopStack.back();
  // Sparse-iterator strategy: only sparse_tensor.iterate needs an explicit
  // yield of the reduction values.
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    // ...
    loopStack.pop_back();
    return;

  // ...
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    // scf.while/scf.for insert an implicit yield when there are no iteration
    // arguments; in that case new code must go before that yield.
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
    // ...
  }

  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
  } else {
    exitForLoop(rewriter, loc, reduc);
  }

  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
// genCoIteration: builds the scf.while loop that co-iterates a set of sparse
// iterators. The loop-carried values are, in order: (optionally) the user
// reduction values, the cursor values of every sparse iterator, (otherwise)
// the reduction values, and an optional universal index.
  // ...
    ivs.append(reduc.begin(), reduc.end());
  // ...
    ivs.append(itVals.begin(), itVals.end());
  // ...
    ivs.append(reduc.begin(), reduc.end());
  // ...
    ivs.push_back(uniIdx);

  // Every loop-carried value must be valid.
  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
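  // Sketch (exposition, not in the original) of the loop built below, for two
  // sparse iterators and no universal index:
  //
  //   %res:N = scf.while (%cur0 = ..., %cur1 = ..., %red = ...) {
  //     %c0   = <iterator 0 not exhausted>    // it->genWhileCond()
  //     %c1   = <iterator 1 not exhausted>
  //     %cond = arith.andi %c0, %c1
  //     scf.condition(%cond) %cur0, %cur1, %red
  //   } do {
  //   ^bb0(%cur0, %cur1, %red):
  //     // each iterator is deref'ed to obtain its coordinate; the minimum
  //     // coordinate drives the user code
  //   }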
  // ...
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  // Generate the loop conditions in the "before" region: the loop keeps
  // running while every sparse iterator still has entries.
  Value whileCond = nullptr;
  // ...
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
  // ...
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
  // ...

  // Generate the loop body in the "after" region: dereference each iterator to
  // cache its current coordinate.
  // ...
    it->deref(builder, loc);
  // In-place update of the reduction variables.
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
    // ...
  // ...
  // With a universal index, the minimum coordinate is simply the last block
  // argument of the "after" region.
    min = whileOp.getAfterArguments().back();
  // ...
  return {whileOp, min};