template <typename SubClass, typename SourceOp>
  using OpAdaptor = typename SourceOp::Adaptor;
  LogicalResult matchAndRewrite(SourceOp op,
    for (Value &in : deMappedIns) {
        ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
struct AffineExprAdmissibleVisitor
  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {}
  operator bool() { return admissible; }
  bool admissible = true;
using InadmissInfo = std::pair<BitVector, BitVector>;
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    collector.walkPostOrder(map.getResult(lvl));
  ret.second = collector.dims;
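To make the admissibility check above concrete, the following self-contained sketch (not part of this file; names are illustrative) builds an affine map whose second result is a compound expression, as arises after composing a block-style dim2lvl map into an indexing map; that is the kind of result the visitor above treats as inadmissible for outputs, while a plain dimension result like d1 passes.

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

// Illustrative only: (d0, d1) -> (d1, d0 floordiv 2). The floordiv result is
// a compound expression; a plain AffineDimExpr result is the admissible case.
void buildCompoundMapExample() {
  mlir::MLIRContext ctx;
  mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);
  mlir::AffineExpr d1 = mlir::getAffineDimExpr(1, &ctx);
  mlir::AffineExpr blocked = d0.floorDiv(mlir::getAffineConstantExpr(2, &ctx));
  mlir::AffineMap map =
      mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, blocked}, &ctx);
  (void)map;
}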
  auto [inAdLvls, usedDims] = info;
  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    AffineDimCollector usedInLvl(idxMap.getNumDims());
      usedInLvl.walkPostOrder(e);
    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();
    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  for (unsigned l : inAdLvls.set_bits()) {
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
      transItTps.push_back(itTps[d]);
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  itTps.assign(transItTps.begin(), transItTps.end());
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    if (stt && !stt->isIdentity()) {
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
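The composition on the line above rewrites each indexing map from dimension space into level space. A self-contained sketch of the same step for a CSC-style permutation (names are illustrative, not from this file):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

// Illustrative only: with dim2Lvl = (d0, d1) -> (d1, d0) and a matmul-style
// access (i, j, k) -> (i, k), the composed map addresses the operand by level
// coordinates: (i, j, k) -> (k, i).
void composeDim2LvlExample() {
  mlir::MLIRContext ctx;
  mlir::AffineExpr i = mlir::getAffineDimExpr(0, &ctx);
  mlir::AffineExpr k = mlir::getAffineDimExpr(2, &ctx);
  mlir::AffineMap idxMap = mlir::AffineMap::get(3, 0, {i, k}, &ctx);
  mlir::AffineMap dim2Lvl = mlir::AffineMap::getPermutationMap({1u, 0u}, &ctx);
  mlir::AffineMap lvlIdxMap = dim2Lvl.compose(idxMap); // (i, j, k) -> (k, i)
  (void)lvlIdxMap;
}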
                               unsigned pos, int64_t lvlSz) {
    if (ShapedType::isStatic(lvlSz)) {
      cstMapping.try_emplace(divExp, c0);
      cstMapping.try_emplace(modExp, lvlExp);
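The constant mapping built here exploits static level sizes: a level coordinate is bounded by its level size, so expressions that re-derive the dimension from it fold away. A standalone sketch of the same folding (illustrative; not this file's helper):

#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

// Illustrative only: with a static level size of 4, d0 floordiv 4 folds to 0
// and d0 mod 4 folds back to d0, since 0 <= d0 < 4 for the level coordinate.
void foldStaticLevelSizeExample() {
  mlir::MLIRContext ctx;
  mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, &ctx);
  mlir::AffineExpr c4 = mlir::getAffineConstantExpr(4, &ctx);
  llvm::DenseMap<mlir::AffineExpr, mlir::AffineExpr> cstMapping;
  cstMapping.try_emplace(d0.floorDiv(c4), mlir::getAffineConstantExpr(0, &ctx));
  cstMapping.try_emplace(d0 % c4, d0);
  // (d0) -> (d0 floordiv 4, d0 mod 4) becomes (d0) -> (0, d0).
  mlir::AffineMap map =
      mlir::AffineMap::get(1, 0, {d0.floorDiv(c4), d0 % c4}, &ctx);
  mlir::AffineMap folded =
      map.replace(cstMapping, /*numResultDims=*/1, /*numResultSyms=*/0);
  (void)folded;
}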
  unsigned boundedNum = 0;
  for (OpOperand &operand : op->getOpOperands()) {
    if (!stt || !stt->getEncoding())
    unsigned tid = operand.getOperandNumber();
    bool isOutput = &operand == op.getDpsInitOperand(0);
    auto [inAdLvls, dimExprs] = inAdInfo;
    for (unsigned d : dimExprs.set_bits()) {
    if (inAdLvls.count() != 0) {
      unsigned position = 0;
      for (unsigned lvl : inAdLvls.set_bits()) {
        int64_t lvlSz = lvlShape[lvl];
        populateCstMapping(cstMapping, position, lvlSz);
      for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
        AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
        idxMapArray[tid] = transMap.replace(
      boundedNum += inAdLvls.count();
  llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
  return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),
  return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
          linalgOp, "the sparse kernel can not be sparsified.");
    Value res = linalgOp.getResult(0);
    auto [idxMap, itTp] = *transMap;
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    if (stt && stt->hasEncoding()) {
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
    bool isAdmissible = false;
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (isAdmissibleOrder(linalgOp, order)) {
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
          linalgOp, "the sparse kernel can not be scheduled.");
    linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    curItTypes.reserve(preItTypes.size());
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
      idxMap = idxMap.compose(order);
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
                              linalg::LinalgOp linalgOp,
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
        return !llvm::isa<AffineDimExpr>(exp);
      if (!srcEnc || hasCompExpr)
      assert(stt.isIdentity());
      idxMap = idxMap.compose(order);
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      llvm::sort(lvlSeq, llvm::less_first());
      llvm::to_vector(llvm::make_second_range(lvlSeq));
      assert(!dimToLvl.isIdentity());
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
      linalgOp->setOperand(t->getOperandNumber(), dst);
      bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);
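The level-sequence trick above derives a new dim-to-level permutation by sorting (level, position) pairs and reading off the positions. A small standalone illustration with concrete numbers (names are illustrative, not this file's helpers):

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <utility>

// Illustrative only: levels encountered in loop order 2, 0, 1 yield the pairs
// (2,0), (0,1), (1,2); sorting by level and taking the second components
// produces the permutation [1, 2, 0].
void lvlSeqPermutationExample() {
  llvm::SmallVector<std::pair<unsigned, unsigned>> lvlSeq = {{2, 0}, {0, 1}, {1, 2}};
  llvm::sort(lvlSeq, llvm::less_first());
  auto perm = llvm::to_vector(llvm::make_second_range(lvlSeq)); // [1, 2, 0]
  (void)perm;
}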
template <typename AllocOp>
  LogicalResult matchAndRewrite(AllocOp op,
    maxDimCrds.reserve(stt.getDimRank());
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
        dynLvlSzs.push_back(sz);
    assert(dynSz.empty());
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
                                             adaptor.getDest(), lvlCrd);
    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
  LogicalResult matchAndRewrite(AssembleOp op,
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
    op.getTensorMutable().assign(adaptor.getTensor());
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    for (auto r : op.getResults())
        r.setType(stt->getDemappedType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    for (Type t : blockArgTps)
    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    unsigned numInitArgs = op.getInitArgs().size();
    if (numInitArgs != 0) {
          stt && !stt->isIdentity()) {
          genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        YieldOp::create(rewriter, loc, y);
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
  patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
               TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
               SparseDisassembleDemapper, TensorInsertDemapper,
               ForeachOpDemapper>(patterns.getContext());
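A minimal sketch of driving these patterns from a pass, assuming the populateSparseReinterpretMap entry point declared for this file; the scope enumerator spelling and the runDemo wrapper are illustrative assumptions:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Illustrative only: collect the reinterpret-map patterns and apply them
// greedily to a root operation. ReinterpretMapScope::kAll is an assumed
// enumerator; adjust to the actual enum if it differs.
mlir::LogicalResult runDemo(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::populateSparseReinterpretMap(patterns,
                                     mlir::ReinterpretMapScope::kAll);
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}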
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types, ValueRange outs)
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput)
static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, SmallVector<utils::IteratorType> &itTps)
static std::optional<std::pair<ArrayAttr, ArrayAttr>> translateMap(linalg::GenericOp op, PatternRewriter &rewriter)