18 if (enc.getLvlType(lvl).isWithPosLT())
19 fields.push_back(enc.getPosMemRefType());
20 if (enc.getLvlType(lvl).isWithCrdLT())
21 fields.push_back(enc.getCrdMemRefType());
26 static std::optional<LogicalResult>
30 for (
Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
35 fields.append({idxTp, idxTp});
39 static std::optional<LogicalResult>
44 assert(itTp.getEncoding().getBatchLvlRank() == 0);
45 if (!itTp.isUnique()) {
47 fields.push_back(idxTp);
49 fields.push_back(idxTp);
56 ArrayRef<std::unique_ptr<SparseIterator>> iters,
59 if (newBlocks.empty())
63 Block *newBlock = newBlocks.front();
64 Block *oldBlock = oldBlocks.front();
68 for (
unsigned i : caseBits.
bits()) {
70 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
72 casePred = rewriter.
create<arith::AndIOp>(loc, casePred, pred);
74 scf::IfOp ifOp = rewriter.
create<scf::IfOp>(
79 rewriter.
eraseBlock(&ifOp.getThenRegion().front());
82 blockArgs.push_back(loopCrd);
83 for (
unsigned idx : caseBits.
bits())
84 llvm::append_range(blockArgs, iters[idx]->getCursor());
90 for (
auto [from, to] : llvm::zip_equal(oldBlock->
getArguments(), blockArgs)) {
91 mapping.
map(from, to);
97 ifOp.getThenRegion().
begin(), mapping);
99 ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
102 auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
106 rewriter.
create<scf::YieldOp>(loc, yields);
111 newBlocks.drop_front(),
112 oldBlocks.drop_front(), userReduc);
114 rewriter.
create<scf::YieldOp>(loc, res);
117 return ifOp.getResults();
128 auto [lo, hi] = it->
genForCond(rewriter, loc);
130 scf::ForOp forOp = rewriter.
create<scf::ForOp>(
131 loc, lo, hi, step, reduc,
140 it, forOp.getRegionIterArgs());
143 rewriter.
create<scf::YieldOp>(loc, ret);
149 llvm::append_range(ivs, it->
getCursor());
152 auto whileOp = rewriter.
create<scf::WhileOp>(loc, types, ivs);
160 auto [whileCond, remArgs] = it->
genWhileCond(rewriter, loc, bArgs);
164 Region &dstRegion = whileOp.getAfter();
166 ValueRange aArgs = whileOp.getAfterArguments();
168 aArgs = aArgs.take_front(reduc.size());
176 llvm::append_range(yields, ret);
177 llvm::append_range(yields, it->
forward(rewriter, loc));
178 rewriter.
create<scf::YieldOp>(loc, yields);
186 class ExtractIterSpaceConverter
191 matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
197 llvm::getSingleElement(adaptor.getTensor()), 0,
198 op.getLvlRange(), adaptor.getParentIter());
211 matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
214 Value pos = adaptor.getIterator().back();
216 loc, llvm::getSingleElement(adaptor.getTensor()));
226 matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
228 if (!op.getCrdUsedLvls().empty())
230 op,
"non-empty coordinates list not implemented.");
235 op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
237 std::unique_ptr<SparseIterator> it =
238 iterSpace.extractIterator(rewriter, loc);
241 for (
ValueRange inits : adaptor.getInitArgs())
242 llvm::append_range(ivs, inits);
245 unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
247 if (failed(typeConverter->convertSignatureArgs(
248 op.getBody()->getArgumentTypes(), signatureConversion)))
250 op,
"failed to convert iterate region argurment types");
253 op.getBody(), signatureConversion, getTypeConverter());
255 rewriter, loc, it.get(), ivs,
258 SmallVector<Value> blockArgs(reduc);
261 llvm::append_range(blockArgs, it->getCursor());
263 Block *dstBlock = &loopBody.getBlocks().front();
264 rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
266 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
269 SmallVector<Value> result(yield.getResults());
270 rewriter.eraseOp(yield);
283 matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
285 assert(op.getSpaceDim() == 1 &&
"Not implemented");
289 for (
auto [idx, spaceTp] :
llvm::enumerate(op.getIterSpaces().getTypes()))
290 if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(),
isDenseLT))
298 any_of(op.getRegionDefinedSpaces(), [denseBits](
I64BitSet caseBits) {
300 if (caseBits.count() == 0)
303 return caseBits.isSubSetOf(denseBits);
305 assert(!needUniv &&
"Not implemented");
310 for (
Region ®ion : op.getCaseRegions()) {
317 blockTypeMapping))) {
319 op,
"failed to convert coiterate region argurment types");
323 block, blockTypeMapping, getTypeConverter()));
324 newToOldBlockMap[newBlocks.back()] = block;
329 for (
auto [spaceTp, spaceVals] : llvm::zip_equal(
330 op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
333 cast<IterSpaceType>(spaceTp), spaceVals, 0));
335 iters.push_back(spaces.back().extractIterator(rewriter, loc));
338 auto getFilteredIters = [&iters](
I64BitSet caseBits) {
341 for (
auto idx : caseBits.bits())
342 validIters.push_back(iters[idx].get());
349 llvm::append_range(userReduc, r);
356 for (
auto [r, caseBits] :
357 llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
358 assert(caseBits.count() > 0 &&
"Complement space not implemented");
363 if (validIters.size() > 1) {
364 auto [loop, loopCrd] =
374 op.getSubCasesOf(r->getParent()->getRegionNumber());
376 for (
Region *r : subCases) {
377 newBlocks.push_back(&r->front());
378 oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
380 assert(!subCases.empty());
383 rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
389 loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
390 it->forwardIf(rewriter, loc, cmp);
391 llvm::append_range(nextIterYields, it->getCursor());
393 rewriter.
create<scf::YieldOp>(loc, nextIterYields);
397 ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
399 iterVals = it->linkNewScope(iterVals);
400 assert(iterVals.empty());
402 ValueRange curResult = loop->getResults().take_front(userReduc.size());
403 userReduc.assign(curResult.begin(), curResult.end());
406 assert(caseBits.count() == 1);
410 rewriter, loc, validIters.front(), userReduc,
415 SmallVector<Value> blockArgs(reduc);
416 blockArgs.push_back(it->deref(rewriter, loc));
417 llvm::append_range(blockArgs, it->getCursor());
419 Block *dstBlock = &dstRegion.getBlocks().front();
420 rewriter.inlineBlockBefore(
421 block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
422 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
423 SmallVector<Value> result(yield.getResults());
424 rewriter.eraseOp(yield);
428 userReduc.assign(curResult.begin(), curResult.end());
440 addConversion([](
Type type) {
return type; });
444 addSourceMaterialization([](
OpBuilder &builder, IterSpaceType spTp,
456 patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
457 SparseIterateOpConverter, SparseCoIterateOpConverter>(
static ValueRange genLoopWithIterator(PatternRewriter &rewriter, Location loc, SparseIterator *it, ValueRange reduc, function_ref< SmallVector< Value >(PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc)> bodyBuilder)
static ValueRange genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, Value loopCrd, ArrayRef< std::unique_ptr< SparseIterator >> iters, ArrayRef< Block * > newBlocks, ArrayRef< Block * > oldBlocks, ArrayRef< Value > userReduc)
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIteratorType(IteratorType itTp, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl< Type > &fields)
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
ValueRange forward(OpBuilder &b, Location l)
ValueRange linkNewScope(ValueRange pos)
ValueRange getCursor() const
virtual bool iteratableByFor() const
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
uint64_t Level
The type of level identifiers and level-ranks.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)
bool isDenseLT(LevelType lt)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SparseIterationTypeConverter()