#define CMPI(p, l, r)                                                          \
  (builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::p, (l), (r))       \
       .getResult())
#define C_IDX(v) (constantIndex(builder, loc, (v)))
#define YIELD(vs) (builder.create<scf::YieldOp>(loc, (vs)))
#define ADDI(lhs, rhs) (builder.create<arith::AddIOp>(loc, (lhs), (rhs)))
#define ANDI(lhs, rhs) (builder.create<arith::AndIOp>(loc, (lhs), (rhs)))
#define SUBI(lhs, rhs) (builder.create<arith::SubIOp>(loc, (lhs), (rhs)))
#define MULI(lhs, rhs) (builder.create<arith::MulIOp>(loc, (lhs), (rhs)))
#define REMUI(lhs, rhs) (builder.create<arith::RemUIOp>(loc, (lhs), (rhs)))
#define DIVUI(lhs, rhs) (builder.create<arith::DivUIOp>(loc, (lhs), (rhs)))
#define SELECT(c, l, r) (builder.create<arith::SelectOp>(loc, (c), (l), (r)))
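// Illustrative sketch (not part of the original file): the macros above assume
// an OpBuilder named `builder` and a Location named `loc` in scope. Under that
// assumption, clamping a hypothetical index value `i` to the range [0, n)
// could be written as:
//
//   Value inBounds = CMPI(ult, i, n);           // i < n
//   Value last = SUBI(n, C_IDX(1));             // n - 1
//   Value clamped = SELECT(inBounds, i, last);  // i < n ? i : n - 1
//
// `i` and `n` are hypothetical index-typed SSA values, not defined in this
// file.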
  memref = builder.create<memref::CastOp>(
  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
    return true;
  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
    return true;
  return cast<Value>(ofr);
  if (padOp && stt.has_value() && stt->hasEncoding() &&
      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
      stt->getEncoding().isIdentity()) {
                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
      return padOp.getSource();
                         bool isSparseOut, unsigned numLoops,
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);

                             bool isSparseOut, unsigned numLoops,
  this->loopTag = loopTag;
  this->hasOutput = hasOutput;
  this->isSparseOut = isSparseOut;
  this->emitStrategy = emitStrategy;

  const unsigned numManifestTensors = ts.size();
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
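  // Besides the manifest tensors in `ts`, one synthetic tensor (with id
  // `synTensorId`) is appended. Its levels are dense spaces sized by the loop
  // bounds (see `loopHighs` and makeSynLevelAndIterator below), so loops that
  // are not tied to any manifest tensor level can still be iterated.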
  this->tensors.assign(ts.begin(), ts.end());
  this->valBuffer.assign(numTensors, nullptr);
  this->lvls.resize(numTensors);
  this->iters.resize(numTensors);
  this->spIterVals.resize(numTensors);

  this->loopStack.reserve(numLoops);
  this->loopSeqStack.reserve(numLoops);

  this->dependentLvlMap.assign(
      numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
  this->sliceMeta.assign(
      numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
  this->levelReducedDep.assign(numTensors, std::vector<unsigned>());
  for (TensorId tid = 0; tid < numTensors; tid++) {
    if (tid == synTensorId) {
    const Value t = tensors[tid];
    lvls[tid].resize(lvlRank);
    iters[tid].resize(lvlRank);
    spIterVals[tid].resize(lvlRank);
    loopHighs.assign(numLoops, nullptr);
    levelReducedDep[tid].assign(lvlRank, 0);
    dependentLvlMap[tid].assign(
        lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
    sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
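    // For levels whose coordinates depend on multiple loops (non-trivial
    // affine index expressions), dependentLvlMap[tid][l] records the
    // (LoopId, coefficient) pairs produced by dimGetter, sorted by loop id,
    // and sliceMeta[tid][l] reserves one slot of slice metadata per
    // dependency.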
    if (dimGetter && !isSynTensor(tid)) {
      for (Level l = 0; l < lvlRank; l++) {
        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
        llvm::sort(deps, llvm::less_first());
        dependentLvlMap[tid][l] = std::move(deps);
        unsigned depends = dependentLvlMap[tid][l].size();
        sliceMeta[tid][l].reserve(depends);
std::unique_ptr<SparseIterator>
  Value tensor = tensors[t];
  if (folded != tensor) {
    if (padOp.getPaddedDims().test(l)) {
  if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
        std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);

  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  const auto shape = rtp.getShape();
  bool isOutput = isOutputTensor(t);
  Type elementType = stt.getElementType();
  if (!stt.hasEncoding()) {
    if (llvm::isa_and_nonnull<tensor::ExtractSliceOp>(tensor.getDefiningOp()))
        builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    if (isOutput && updater)
      denseVal = updater(builder, loc, denseVal, tensor);
    valBuffer[t] = denseVal;
    valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);

  for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
    Value sz = loopHighs[i] = synSetter(builder, loc, i);
    lvls[synId][i] = std::move(stl);
    iters[synId][i].emplace_back(std::move(it));
  const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
  const Level lvlRank = stt.getLvlRank();
  for (Level l = 0; l < lvlRank; l++) {
    if (!dependentLvlMap[t][l].empty())
    auto it = makeLevelIterator(builder, loc, t, l);
    iters[t][l].emplace_back(std::move(it));
  initSubSectIterator(builder, loc);

  for (TensorId t = 0, e = tensors.size(); t < e; t++) {
    auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
    auto remDepStack = dependentLvlMap;
    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));

  if (depRedOrder.empty())
  llvm::sort(depRedOrder, llvm::less_first());
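  // `depRedOrder` entries are keyed by LoopId, so after this sort the
  // dependent (tensor, level) pairs below are processed in increasing loop
  // order.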
  for (auto [loop, t, lvl] : depRedOrder) {
    std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
    assert(curDep.first == loop);
    remDepStack[t][lvl].pop_back();
    auto lvlIt = makeLevelIterator(builder, loc, t, lvl);
    if (!parent && lvl > 0) {
      if (dependentLvlMap[t][lvl - 1].empty()) {
        parent = iters[t][lvl - 1].back().get();
    std::unique_ptr<SparseIterator> it;
    if (!remDepStack[t][lvl].empty()) {
      for (auto [loop, stride] : remDepStack[t][lvl]) {
          std::move(lvlIt), size, curDep.second,
          std::move(lvlIt), loopHighs[loop],
          curDep.second, emitStrategy);
    lastIter[t] = it.get();
    iters[t][lvl].emplace_back(std::move(it));
void LoopEmitter::categorizeIterators(
      raIters.push_back(it);
      spIters.push_back(it);
  // Sort the sparse iterators by kind (descending) so that co-iteration
  // handles them in a deterministic, kind-based order.
  std::stable_sort(spIters.begin(), spIters.end(), [](auto lhs, auto rhs) {
    return static_cast<uint8_t>(lhs->kind) > static_cast<uint8_t>(rhs->kind);
  assert(loopSeqStack.size() == loopStack.size());
      levelReducedDep[tid][lvl]++;
      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
  loopSeqStack.emplace_back(C_IDX(0), tidLvls.vec());

  assert(loopSeqStack.size() == loopStack.size() + 1);
      levelReducedDep[tid][lvl]--;
  loopSeqStack.pop_back();

    const auto loopId = cast<AffineDimExpr>(a).getPosition();
    return loopStack[loopId].iv;
    auto binOp = cast<AffineBinaryOpExpr>(a);
                genAffine(builder, loc, binOp.getRHS()));
    auto binOp = cast<AffineBinaryOpExpr>(a);
                genAffine(builder, loc, binOp.getRHS()));
    int64_t c = cast<AffineConstantExpr>(a).getValue();
    llvm_unreachable("unexpected affine subscript");
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
  auto [lo, hi] = iter.genForCond(builder, loc);
    scf::ParallelOp parOp =
        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
    assert(parOp.getNumReductions() == reduc.size());
    iv = parOp.getInductionVars()[0];
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = parOp.getInitVals()[i];
    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
    iv = forOp.getInductionVar();
    assert(forOp.getNumRegionIterArgs() == reduc.size());
    for (int i = 0, e = reduc.size(); i < e; i++)
      reduc[i] = forOp.getRegionIterArg(i);
  crd = iter.deref(builder, loc);
    iter.locate(builder, loc, iv);

std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
                         needsUniv ? loopSeqStack.back().first : nullptr);
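  // Whether a simple scf.for suffices: more than one sparse iterator requires
  // while-based co-iteration; a single sparse iterator is only eligible if it
  // reports itself as iterable by a for loop.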
  if (spIters.size() > 1)
  if (spIters.size() == 1)
    return spIters.front()->iteratableByFor();
  auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
  Region &caseRegion = coIterOp.getRegion(caseIdx);
         "re-initialize the same coiteration case region.");
  TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
  blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
  for (auto i : caseBit.bits()) {
    blockArgTps.push_back(
        cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
  loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
  ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
  ValueRange iters = coIterOp.getRegionIterators(caseIdx);
      spIterVals[tl.first][tl.second] = iters.front();
      iters = iters.drop_front();
      spIterVals[tl.first][tl.second] = nullptr;
  assert(iters.empty());

    if (tidLvls.size() == 1) {
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      IterateOp iterOp = builder.create<IterateOp>(
          loc, extractSpaceOp.getExtractedSpace(), reduc);
      spIterVals[tid][lvl] = iterOp.getIterator();
      llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
                             iterOp.getCrds().front(), loopTag);
      Value t = tensors[tid];
      ExtractIterSpaceOp extractSpaceOp =
          lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
                   : builder.create<ExtractIterSpaceOp>(
                         loc, t, spIterVals[tid][lvl - 1], lvl);
      spaces.push_back(extractSpaceOp.getExtractedSpace());
    auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
    loopStack.emplace_back(tidLvls, coIterOp, nullptr,
  tryParallel = tryParallel && reduc.size() <= 1;
  categorizeIterators(tidLvls, raIters, spIters);
  needsUniv = !spIters.empty() && needsUniv;
  if (shouldIteratedByForLoop(spIters) && !needsUniv) {
    assert(spIters.size() <= 1);
    SparseIterator &it = spIters.empty() ? *raIters.front() : *spIters.front();
        emitForLoopOverTensorAtLvl(builder, loc, it, reduc, tryParallel);
    for (auto *it : spIters) {
    for (auto *it : raIters)
        emitWhileLoopOverTensorsAtLvls(builder, loc, spIters, reduc, needsUniv);
    it->locate(builder, loc, iv);

      lvl == 0 ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);
  it.locate(builder, loc, lvlCrd);

  bool hasParent = lvl == 0 || !dependentLvlMap[tid][lvl].empty();
      hasParent ? nullptr : iters[tid][lvl - 1].back().get();
  auto &it = getCurIterator(tid, lvl);
  it.genInit(builder, loc, parent);
  const LoopInfo &loopInfo = loopStack.back();
    auto iterateOp = llvm::cast<IterateOp>(loopInfo.loop);
    assert(reduc.size() == iterateOp.getNumResults());
    rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    llvm::copy(iterateOp.getResults(), reduc.begin());

  if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
    if (!reduc.empty()) {
      assert(reduc.size() == forOp.getNumResults());
      rewriter.create<scf::YieldOp>(loc, reduc);
    llvm::copy(forOp.getResults(), reduc.begin());
    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
    if (!reduc.empty()) {
      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
      Operation *redExp = reduc.front().getDefiningOp();
      assert(redExp->getUses().empty());
      Value redVal = parOp.getInitVals().front();
      unsigned numUsers = 0;
        if (op->getParentOp() == parOp)
      assert(numUsers == 1);
      // Move the reduction computation into the scf.reduce block and rewire
      // its operands to the block arguments.
      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
      Block *redBlock = &redOp.getReductions().front().front();
          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
      reduc[i] = parOp.getResult(i);
  const LoopInfo &loopInfo = loopStack.back();
  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
  Value iv = loopInfo.iv;

  Value uniIdx = whileOp.getResults().back();
    it.locate(builder, loc, uniIdx);
  for (auto &i : reduc) {
    operands.push_back(i);
    i = whileRes.front();
    whileRes = whileRes.drop_front();
  if (operands.size() < whileOp.getNumResults()) {
    assert(operands.size() + 1 == whileOp.getNumResults());
    operands.push_back(ADDI(iv, one));
    loopSeqStack.back().first = whileOp->getResults().back();
  if (!operands.empty())

  const LoopInfo &loopInfo = loopStack.back();
    if (isa<IterateOp>(p))
      rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
    loopStack.pop_back();
  if (!loopInfo.userCodeBlock->empty() &&
      llvm::isa<scf::YieldOp>(&loopInfo.userCodeBlock->back())) {
    assert(loopInfo.userCodeBlock->back().getNumResults() == 0);
  if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
    exitWhileLoop(rewriter, loc, reduc);
    exitForLoop(rewriter, loc, reduc);
  assert(loopStack.size() == loopSeqStack.size());
  loopStack.pop_back();
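// Illustrative usage sketch (not from this file): a client such as the
// sparsifier typically drives the emitter along these lines; the local
// variable names here are hypothetical.
//
//   LoopEmitter emitter;
//   emitter.initialize(tensors, /*loopTag=*/nullptr, /*hasOutput=*/true,
//                      /*isSparseOut=*/false, numLoops);
//   emitter.initializeLoopEmit(builder, loc);
//   emitter.enterNewLoopSeq(builder, loc, tidLvls);
//   Operation *loop = emitter.enterCoIterationOverTensorsAtLvls(
//       builder, loc, tidLvls, /*numCases=*/1, reduc);
//   // ... emit the loop body ...
//   emitter.exitCurrentLoop(rewriter, loc, reduc);
//   emitter.exitCurrentLoopSeq(builder, loc);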
    ivs.append(reduc.begin(), reduc.end());
    ivs.append(itVals.begin(), itVals.end());
    ivs.append(reduc.begin(), reduc.end());
  ivs.push_back(uniIdx);

  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
  auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);

  Value whileCond = nullptr;
    auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
    whileCond = !whileCond ? cond : ANDI(whileCond, cond);
  assert(bArgs.size() == reduc.size() + (uniIdx ? 1 : 0));
    it->deref(builder, loc);
  for (unsigned i = 0, e = reduc.size(); i < e; i++)
  min = whileOp.getAfterArguments().back();
  return {whileOp, min};
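// Rough shape of the co-iteration loop built above (illustrative only; the
// exact operand lists depend on the iterators involved):
//
//   %res:K = scf.while (%cursor0 = ..., %red = ..., %univ = %c0) ... {
//     // one condition per sparse iterator, combined with arith.andi
//     %c = arith.andi %cond0, %cond1 : i1
//     scf.condition(%c) %cursor0, ..., %red, %univ : ...
//   } do {
//   ^bb0(%cursor0: ..., %red: ..., %univ: index):
//     // deref the iterators, compute the minimum coordinate, emit the body
//     scf.yield ...
//   }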