17 assert(values.size() == 1 &&
"expected single value");
18 return values.front();
24 if (enc.getLvlType(lvl).isWithPosLT())
25 fields.push_back(enc.getPosMemRefType());
26 if (enc.getLvlType(lvl).isWithCrdLT())
27 fields.push_back(enc.getCrdMemRefType());
32 static std::optional<LogicalResult>
36 for (
Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
41 fields.append({idxTp, idxTp});
45 static std::optional<LogicalResult>
50 assert(itTp.getEncoding().getBatchLvlRank() == 0);
51 if (!itTp.isUnique()) {
53 fields.push_back(idxTp);
55 fields.push_back(idxTp);
62 ArrayRef<std::unique_ptr<SparseIterator>> iters,
65 if (newBlocks.empty())
69 Block *newBlock = newBlocks.front();
70 Block *oldBlock = oldBlocks.front();
74 for (
unsigned i : caseBits.
bits()) {
76 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
78 casePred = rewriter.
create<arith::AndIOp>(loc, casePred, pred);
80 scf::IfOp ifOp = rewriter.
create<scf::IfOp>(
85 rewriter.
eraseBlock(&ifOp.getThenRegion().front());
88 blockArgs.push_back(loopCrd);
89 for (
unsigned idx : caseBits.
bits())
90 llvm::append_range(blockArgs, iters[idx]->getCursor());
96 for (
auto [from, to] : llvm::zip_equal(oldBlock->
getArguments(), blockArgs)) {
97 mapping.
map(from, to);
103 ifOp.getThenRegion().
begin(), mapping);
105 ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
108 auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
112 rewriter.
create<scf::YieldOp>(loc, yields);
117 newBlocks.drop_front(),
118 oldBlocks.drop_front(), userReduc);
120 rewriter.
create<scf::YieldOp>(loc, res);
123 return ifOp.getResults();
134 auto [lo, hi] = it->
genForCond(rewriter, loc);
136 scf::ForOp forOp = rewriter.
create<scf::ForOp>(
137 loc, lo, hi, step, reduc,
146 it, forOp.getRegionIterArgs());
149 rewriter.
create<scf::YieldOp>(loc, ret);
155 llvm::append_range(ivs, it->
getCursor());
158 auto whileOp = rewriter.
create<scf::WhileOp>(loc, types, ivs);
166 auto [whileCond, remArgs] = it->
genWhileCond(rewriter, loc, bArgs);
170 Region &dstRegion = whileOp.getAfter();
172 ValueRange aArgs = whileOp.getAfterArguments();
174 aArgs = aArgs.take_front(reduc.size());
182 llvm::append_range(yields, ret);
183 llvm::append_range(yields, it->
forward(rewriter, loc));
184 rewriter.
create<scf::YieldOp>(loc, yields);
192 class ExtractIterSpaceConverter
197 matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
204 op.getLvlRange(), adaptor.getParentIter());
217 matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
220 Value pos = adaptor.getIterator().back();
232 matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
234 if (!op.getCrdUsedLvls().empty())
236 op,
"non-empty coordinates list not implemented.");
241 op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
243 std::unique_ptr<SparseIterator> it =
244 iterSpace.extractIterator(rewriter, loc);
247 for (
ValueRange inits : adaptor.getInitArgs())
248 llvm::append_range(ivs, inits);
251 unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
253 if (failed(typeConverter->convertSignatureArgs(
254 op.getBody()->getArgumentTypes(), signatureConversion)))
256 op,
"failed to convert iterate region argurment types");
259 op.getBody(), signatureConversion, getTypeConverter());
261 rewriter, loc, it.get(), ivs,
264 SmallVector<Value> blockArgs(reduc);
267 llvm::append_range(blockArgs, it->getCursor());
269 Block *dstBlock = &loopBody.getBlocks().front();
270 rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
272 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
275 SmallVector<Value> result(yield.getResults());
276 rewriter.eraseOp(yield);
289 matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
291 assert(op.getSpaceDim() == 1 &&
"Not implemented");
295 for (
auto [idx, spaceTp] :
llvm::enumerate(op.getIterSpaces().getTypes()))
296 if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(),
isDenseLT))
304 any_of(op.getRegionDefinedSpaces(), [denseBits](
I64BitSet caseBits) {
306 if (caseBits.count() == 0)
309 return caseBits.isSubSetOf(denseBits);
311 assert(!needUniv &&
"Not implemented");
316 for (
Region ®ion : op.getCaseRegions()) {
323 blockTypeMapping))) {
325 op,
"failed to convert coiterate region argurment types");
329 block, blockTypeMapping, getTypeConverter()));
330 newToOldBlockMap[newBlocks.back()] = block;
335 for (
auto [spaceTp, spaceVals] : llvm::zip_equal(
336 op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
339 cast<IterSpaceType>(spaceTp), spaceVals, 0));
341 iters.push_back(spaces.back().extractIterator(rewriter, loc));
344 auto getFilteredIters = [&iters](
I64BitSet caseBits) {
347 for (
auto idx : caseBits.bits())
348 validIters.push_back(iters[idx].get());
355 llvm::append_range(userReduc, r);
362 for (
auto [r, caseBits] :
363 llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
364 assert(caseBits.count() > 0 &&
"Complement space not implemented");
369 if (validIters.size() > 1) {
370 auto [loop, loopCrd] =
380 op.getSubCasesOf(r->getParent()->getRegionNumber());
382 for (
Region *r : subCases) {
383 newBlocks.push_back(&r->front());
384 oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
386 assert(!subCases.empty());
389 rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
395 loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
396 it->forwardIf(rewriter, loc, cmp);
397 llvm::append_range(nextIterYields, it->getCursor());
399 rewriter.
create<scf::YieldOp>(loc, nextIterYields);
403 ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
405 iterVals = it->linkNewScope(iterVals);
406 assert(iterVals.empty());
408 ValueRange curResult = loop->getResults().take_front(userReduc.size());
409 userReduc.assign(curResult.begin(), curResult.end());
412 assert(caseBits.count() == 1);
416 rewriter, loc, validIters.front(), userReduc,
421 SmallVector<Value> blockArgs(reduc);
422 blockArgs.push_back(it->deref(rewriter, loc));
423 llvm::append_range(blockArgs, it->getCursor());
425 Block *dstBlock = &dstRegion.getBlocks().front();
426 rewriter.inlineBlockBefore(
427 block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
428 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
429 SmallVector<Value> result(yield.getResults());
430 rewriter.eraseOp(yield);
434 userReduc.assign(curResult.begin(), curResult.end());
446 addConversion([](
Type type) {
return type; });
450 addSourceMaterialization([](
OpBuilder &builder, IterSpaceType spTp,
462 patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
463 SparseIterateOpConverter, SparseCoIterateOpConverter>(
static Value getSingleValue(ValueRange values)
Assert that the given value range contains a single value and return it.
static ValueRange genLoopWithIterator(PatternRewriter &rewriter, Location loc, SparseIterator *it, ValueRange reduc, function_ref< SmallVector< Value >(PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc)> bodyBuilder)
static ValueRange genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, Value loopCrd, ArrayRef< std::unique_ptr< SparseIterator >> iters, ArrayRef< Block * > newBlocks, ArrayRef< Block * > oldBlocks, ArrayRef< Value > userReduc)
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIteratorType(IteratorType itTp, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl< Type > &fields)
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
ValueRange forward(OpBuilder &b, Location l)
ValueRange linkNewScope(ValueRange pos)
ValueRange getCursor() const
virtual bool iteratableByFor() const
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
uint64_t Level
The type of level identifiers and level-ranks.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)
bool isDenseLT(LevelType lt)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SparseIterationTypeConverter()