18 if (enc.getLvlType(lvl).isWithPosLT())
19 fields.push_back(enc.getPosMemRefType());
20 if (enc.getLvlType(lvl).isWithCrdLT())
21 fields.push_back(enc.getCrdMemRefType());
26 static std::optional<LogicalResult>
30 for (
Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
35 fields.append({idxTp, idxTp});
39 static std::optional<LogicalResult>
44 assert(itTp.getEncoding().getBatchLvlRank() == 0);
45 if (!itTp.isUnique()) {
47 fields.push_back(idxTp);
49 fields.push_back(idxTp);
56 ArrayRef<std::unique_ptr<SparseIterator>> iters,
62 Region *b = subCases.front();
65 for (
unsigned i : caseBits.
bits()) {
67 Value pred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
69 casePred = rewriter.
create<arith::AndIOp>(loc, casePred, pred);
71 scf::IfOp ifOp = rewriter.
create<scf::IfOp>(
76 rewriter.
eraseBlock(&ifOp.getThenRegion().front());
79 blockArgs.push_back(loopCrd);
80 for (
unsigned idx : caseBits.
bits())
81 llvm::append_range(blockArgs, iters[idx]->getCursor());
84 for (
auto [from, to] :
86 mapping.
map(from, to);
92 ifOp.getThenRegion().
begin(), mapping);
95 auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
99 rewriter.
create<scf::YieldOp>(loc, yields);
104 subCases.drop_front(), userReduc);
106 rewriter.
create<scf::YieldOp>(loc, res);
109 return ifOp.getResults();
120 auto [lo, hi] = it->
genForCond(rewriter, loc);
122 scf::ForOp forOp = rewriter.
create<scf::ForOp>(loc, lo, hi, step, reduc);
127 if (!forOp.getBody()->empty())
128 rewriter.
eraseOp(&forOp.getBody()->front());
129 assert(forOp.getBody()->empty());
134 it, forOp.getRegionIterArgs());
137 rewriter.
create<scf::YieldOp>(loc, ret);
143 llvm::append_range(ivs, it->
getCursor());
146 auto whileOp = rewriter.
create<scf::WhileOp>(loc, types, ivs);
154 auto [whileCond, remArgs] = it->
genWhileCond(rewriter, loc, bArgs);
158 Region &dstRegion = whileOp.getAfter();
160 ValueRange aArgs = whileOp.getAfterArguments();
162 aArgs = aArgs.take_front(reduc.size());
170 llvm::append_range(yields, ret);
171 llvm::append_range(yields, it->
forward(rewriter, loc));
172 rewriter.
create<scf::YieldOp>(loc, yields);
180 class ExtractIterSpaceConverter
185 matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
192 op.getLvlRange(), adaptor.getParentIter());
195 rewriter.
replaceOp(op, result, resultMapping);
205 matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
208 Value pos = adaptor.getIterator().back();
209 Value valBuf = rewriter.
create<ToValuesOp>(loc, op.getTensor());
219 matchAndRewrite(IterateOp op, OpAdaptor adaptor,
221 if (!op.getCrdUsedLvls().empty())
223 op,
"non-empty coordinates list not implemented.");
228 op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
230 std::unique_ptr<SparseIterator> it =
231 iterSpace.extractIterator(rewriter, loc);
234 for (
ValueRange inits : adaptor.getInitArgs())
235 llvm::append_range(ivs, inits);
239 if (failed(typeConverter->convertSignatureArgs(
240 op.getBody()->getArgumentTypes(), blockTypeMapping)))
242 op,
"failed to convert iterate region argurment types");
245 Block *block = op.getBody();
247 rewriter, loc, it.get(), ivs,
250 SmallVector<Value> blockArgs(reduc);
253 llvm::append_range(blockArgs, it->getCursor());
255 Block *dstBlock = &loopBody.getBlocks().front();
256 rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
258 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
261 SmallVector<Value> result(yield.getResults());
262 rewriter.eraseOp(yield);
267 rewriter.
replaceOp(op, ret, resultMapping);
272 class SparseCoIterateOpConverter
277 matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
279 assert(op.getSpaceDim() == 1 &&
"Not implemented");
283 for (
auto [idx, spaceTp] :
llvm::enumerate(op.getIterSpaces().getTypes()))
284 if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(),
isDenseLT))
292 any_of(op.getRegionDefinedSpaces(), [denseBits](
I64BitSet caseBits) {
294 if (caseBits.count() == 0)
297 return caseBits.isSubSetOf(denseBits);
299 assert(!needUniv &&
"Not implemented");
302 for (
Region &region : op.getCaseRegions()) {
308 blockTypeMapping))) {
310 op,
"failed to convert coiterate region argurment types");
318 for (
auto [spaceTp, spaceVals] : llvm::zip_equal(
319 op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
322 cast<IterSpaceType>(spaceTp), spaceVals, 0));
324 iters.push_back(spaces.back().extractIterator(rewriter, loc));
327 auto getFilteredIters = [&iters](
I64BitSet caseBits) {
330 for (
auto idx : caseBits.bits())
331 validIters.push_back(iters[idx].get());
338 llvm::append_range(userReduc, r);
345 for (
auto [r, caseBits] :
346 llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
347 assert(caseBits.count() > 0 &&
"Complement space not implemented");
352 if (validIters.size() > 1) {
353 auto [loop, loopCrd] =
363 assert(!subCases.empty());
366 iters, subCases, userReduc);
372 loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
373 it->forwardIf(rewriter, loc, cmp);
374 llvm::append_range(nextIterYields, it->getCursor());
376 rewriter.
create<scf::YieldOp>(loc, nextIterYields);
380 ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
382 iterVals = it->linkNewScope(iterVals);
383 assert(iterVals.empty());
385 ValueRange curResult = loop->getResults().take_front(userReduc.size());
386 userReduc.assign(curResult.begin(), curResult.end());
389 assert(caseBits.count() == 1);
391 Block *block = &r.getBlocks().front();
393 rewriter, loc, validIters.front(), userReduc,
398 SmallVector<Value> blockArgs(reduc);
399 blockArgs.push_back(it->deref(rewriter, loc));
400 llvm::append_range(blockArgs, it->getCursor());
402 Block *dstBlock = &dstRegion.getBlocks().front();
403 rewriter.inlineBlockBefore(
404 block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
405 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
406 SmallVector<Value> result(yield.getResults());
407 rewriter.eraseOp(yield);
411 userReduc.assign(curResult.begin(), curResult.end());
423 addConversion([](
Type type) {
return type; });
427 addSourceMaterialization([](
OpBuilder &builder, IterSpaceType spTp,
438 IterateOp::getCanonicalizationPatterns(patterns, patterns.
getContext());
439 patterns.
add<ExtractIterSpaceConverter, ExtractValOpConverter,
440 SparseIterateOpConverter, SparseCoIterateOpConverter>(
static ValueRange genLoopWithIterator(PatternRewriter &rewriter, Location loc, SparseIterator *it, ValueRange reduc, function_ref< SmallVector< Value >(PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc)> bodyBuilder)
static ValueRange genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, Value loopCrd, ArrayRef< std::unique_ptr< SparseIterator >> iters, ArrayRef< Region * > subCases, ArrayRef< Value > userReduc)
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIteratorType(IteratorType itTp, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl< Type > &fields)
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgListType getArguments()
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
OneToNOpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
Stores a 1:N mapping of types and provides several useful accessors.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
ValueRange forward(OpBuilder &b, Location l)
ValueRange linkNewScope(ValueRange pos)
ValueRange getCursor() const
virtual bool iteratableByFor() const
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
uint64_t Level
The type of level identifiers and level-ranks.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)
bool isDenseLT(LevelType lt)
Include the generated interface declarations.
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SparseIterationTypeConverter()