18 if (enc.getLvlType(lvl).isWithPosLT())
19 fields.push_back(enc.getPosMemRefType());
20 if (enc.getLvlType(lvl).isWithCrdLT())
21 fields.push_back(enc.getCrdMemRefType());
23 fields.push_back(IndexType::get(enc.getContext()));
26static std::optional<LogicalResult>
29 auto idxTp = IndexType::get(itSp.getContext());
30 for (
Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
35 fields.append({idxTp, idxTp});
39static std::optional<LogicalResult>
42 auto idxTp = IndexType::get(itTp.getContext());
44 assert(itTp.getEncoding().getBatchLvlRank() == 0);
45 if (!itTp.isUnique()) {
47 fields.push_back(idxTp);
49 fields.push_back(idxTp);
56 ArrayRef<std::unique_ptr<SparseIterator>> iters,
59 if (newBlocks.empty())
63 Block *newBlock = newBlocks.front();
64 Block *oldBlock = oldBlocks.front();
68 for (
unsigned i : caseBits.
bits()) {
70 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
72 casePred = arith::AndIOp::create(rewriter, loc, casePred, pred);
74 scf::IfOp ifOp = scf::IfOp::create(
75 rewriter, loc,
ValueRange(userReduc).getTypes(), casePred,
true);
79 rewriter.
eraseBlock(&ifOp.getThenRegion().front());
82 blockArgs.push_back(loopCrd);
83 for (
unsigned idx : caseBits.
bits())
84 llvm::append_range(blockArgs, iters[idx]->getCursor());
90 for (
auto [from, to] : llvm::zip_equal(oldBlock->
getArguments(), blockArgs)) {
91 mapping.
map(from, to);
97 ifOp.getThenRegion().
begin(), mapping);
99 ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
102 auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
106 scf::YieldOp::create(rewriter, loc, yields);
111 newBlocks.drop_front(),
112 oldBlocks.drop_front(), userReduc);
114 scf::YieldOp::create(rewriter, loc, res);
117 return ifOp.getResults();
128 auto [lo, hi] = it->
genForCond(rewriter, loc);
130 scf::ForOp forOp = scf::ForOp::create(
131 rewriter, loc, lo, hi, step, reduc,
140 it, forOp.getRegionIterArgs());
143 scf::YieldOp::create(rewriter, loc, ret);
145 return forOp.getResults();
149 llvm::append_range(ivs, it->
getCursor());
152 auto whileOp = scf::WhileOp::create(rewriter, loc, types, ivs);
160 auto [whileCond, remArgs] = it->
genWhileCond(rewriter, loc, bArgs);
161 scf::ConditionOp::create(rewriter, loc, whileCond, before->
getArguments());
164 Region &dstRegion = whileOp.getAfter();
166 ValueRange aArgs = whileOp.getAfterArguments();
168 aArgs = aArgs.take_front(reduc.size());
176 llvm::append_range(yields, ret);
177 llvm::append_range(yields, it->
forward(rewriter, loc));
178 scf::YieldOp::create(rewriter, loc, yields);
180 return whileOp.getResults().drop_front(it->
getCursor().size());
186class ExtractIterSpaceConverter
187 :
public OpConversionPattern<ExtractIterSpaceOp> {
189 using OpConversionPattern::OpConversionPattern;
191 matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter)
const override {
193 Location loc = op.getLoc();
196 SparseIterationSpace space(loc, rewriter,
197 llvm::getSingleElement(adaptor.getTensor()), 0,
198 op.getLvlRange(), adaptor.getParentIter());
200 SmallVector<Value>
result = space.toValues();
201 rewriter.replaceOpWithMultiple(op, {
result});
207class ExtractValOpConverter :
public OpConversionPattern<ExtractValOp> {
209 using OpConversionPattern::OpConversionPattern;
211 matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
212 ConversionPatternRewriter &rewriter)
const override {
213 Location loc = op.getLoc();
214 Value pos = adaptor.getIterator().back();
215 Value valBuf = ToValuesOp::create(
216 rewriter, loc, llvm::getSingleElement(adaptor.getTensor()));
217 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
222class SparseIterateOpConverter :
public OpConversionPattern<IterateOp> {
224 using OpConversionPattern::OpConversionPattern;
226 matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter)
const override {
228 if (!op.getCrdUsedLvls().empty())
229 return rewriter.notifyMatchFailure(
230 op,
"non-empty coordinates list not implemented.");
232 Location loc = op.getLoc();
235 op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
237 std::unique_ptr<SparseIterator> it =
238 iterSpace.extractIterator(rewriter, loc);
240 SmallVector<Value> ivs;
241 for (
ValueRange inits : adaptor.getInitArgs())
242 llvm::append_range(ivs, inits);
245 unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
246 TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
247 if (
failed(typeConverter->convertSignatureArgs(
248 op.getBody()->getArgumentTypes(), signatureConversion)))
249 return rewriter.notifyMatchFailure(
250 op,
"failed to convert iterate region argurment types");
252 Block *block = rewriter.applySignatureConversion(
253 op.getBody(), signatureConversion, getTypeConverter());
255 rewriter, loc, it.get(), ivs,
256 [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
257 SparseIterator *it,
ValueRange reduc) -> SmallVector<Value> {
258 SmallVector<Value> blockArgs(reduc);
261 llvm::append_range(blockArgs, it->getCursor());
263 Block *dstBlock = &loopBody.getBlocks().front();
264 rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
266 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
269 SmallVector<Value> result(yield.getResults());
270 rewriter.eraseOp(yield);
274 rewriter.replaceOp(op, ret);
279class SparseCoIterateOpConverter :
public OpConversionPattern<CoIterateOp> {
280 using OpConversionPattern::OpConversionPattern;
283 matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
284 ConversionPatternRewriter &rewriter)
const override {
285 assert(op.getSpaceDim() == 1 &&
"Not implemented");
286 Location loc = op.getLoc();
288 I64BitSet denseBits(0);
289 for (
auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
290 if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(),
isDenseLT))
298 any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
300 if (caseBits.count() == 0)
303 return caseBits.isSubSetOf(denseBits);
305 assert(!needUniv &&
"Not implemented");
308 SmallVector<Block *> newBlocks;
310 for (Region ®ion : op.getCaseRegions()) {
314 TypeConverter::SignatureConversion blockTypeMapping(
317 blockTypeMapping))) {
318 return rewriter.notifyMatchFailure(
319 op,
"failed to convert coiterate region argurment types");
322 newBlocks.push_back(rewriter.applySignatureConversion(
323 block, blockTypeMapping, getTypeConverter()));
324 newToOldBlockMap[newBlocks.back()] = block;
327 SmallVector<SparseIterationSpace> spaces;
328 SmallVector<std::unique_ptr<SparseIterator>> iters;
329 for (
auto [spaceTp, spaceVals] : llvm::zip_equal(
330 op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
333 cast<IterSpaceType>(spaceTp), spaceVals, 0));
335 iters.push_back(spaces.back().extractIterator(rewriter, loc));
338 auto getFilteredIters = [&iters](I64BitSet caseBits) {
340 SmallVector<SparseIterator *> validIters;
341 for (
auto idx : caseBits.bits())
342 validIters.push_back(iters[idx].
get());
347 SmallVector<Value> userReduc;
349 llvm::append_range(userReduc, r);
356 for (
auto [r, caseBits] :
357 llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
358 assert(caseBits.count() > 0 &&
"Complement space not implemented");
361 SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
363 if (validIters.size() > 1) {
364 auto [loop, loopCrd] =
373 SmallVector<Region *> subCases =
374 op.getSubCasesOf(r->getParent()->getRegionNumber());
375 SmallVector<Block *> newBlocks, oldBlocks;
376 for (Region *r : subCases) {
377 newBlocks.push_back(&r->front());
378 oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
380 assert(!subCases.empty());
383 rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
385 SmallVector<Value> nextIterYields(res);
387 for (SparseIterator *it : validIters) {
388 Value cmp = arith::CmpIOp::create(
389 rewriter, loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
390 it->forwardIf(rewriter, loc, cmp);
391 llvm::append_range(nextIterYields, it->getCursor());
393 scf::YieldOp::create(rewriter, loc, nextIterYields);
396 rewriter.setInsertionPointAfter(loop);
397 ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
398 for (SparseIterator *it : validIters)
399 iterVals = it->linkNewScope(iterVals);
400 assert(iterVals.empty());
402 ValueRange curResult = loop->getResults().take_front(userReduc.size());
403 userReduc.assign(curResult.begin(), curResult.end());
406 assert(caseBits.count() == 1);
410 rewriter, loc, validIters.front(), userReduc,
412 [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
415 SmallVector<Value> blockArgs(reduc);
416 blockArgs.push_back(it->deref(rewriter, loc));
417 llvm::append_range(blockArgs, it->getCursor());
419 Block *dstBlock = &dstRegion.getBlocks().front();
420 rewriter.inlineBlockBefore(
421 block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
422 auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
423 SmallVector<Value> result(yield.getResults());
424 rewriter.eraseOp(yield);
428 userReduc.assign(curResult.begin(), curResult.end());
432 rewriter.replaceOp(op, userReduc);
440 addConversion([](
Type type) {
return type; });
444 addSourceMaterialization([](
OpBuilder &builder, IterSpaceType spTp,
446 return UnrealizedConversionCastOp::create(builder, loc,
TypeRange(spTp),
456 patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
457 SparseIterateOpConverter, SparseCoIterateOpConverter>(
static ValueRange genLoopWithIterator(PatternRewriter &rewriter, Location loc, SparseIterator *it, ValueRange reduc, function_ref< SmallVector< Value >(PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc)> bodyBuilder)
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIteratorType(IteratorType itTp, SmallVectorImpl< Type > &fields)
static ValueRange genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, Value loopCrd, ArrayRef< std::unique_ptr< SparseIterator > > iters, ArrayRef< Block * > newBlocks, ArrayRef< Block * > oldBlocks, ArrayRef< Value > userReduc)
static std::optional< LogicalResult > convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl< Type > &fields)
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of the mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
ValueRange forward(OpBuilder &b, Location l)
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
ValueRange linkNewScope(ValueRange pos)
ValueRange getCursor() const
virtual bool iteratableByFor() const
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
uint64_t Level
The type of level identifiers and level-ranks.
std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)
bool isDenseLT(LevelType lt)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
SparseIterationTypeConverter()