template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // ...
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      // ...
        in = ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
      // ...
    }

    // ...
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
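// Sketch of the idea (illustrative; the encoding names below are hypothetical):
// DemapInsRewriter is a CRTP helper pattern. Before calling the derived
// pattern's rewriteOp, it rewrites every operand whose sparse encoding has a
// non-identity dimToLvl map into its "demapped" (level-space) type, roughly
//
//   %d = sparse_tensor.reinterpret_map %t
//       : tensor<?x?xf64, #BSR> to tensor<?x?x2x2xf64, #DemappedBSR>
//
// so the derived pattern only ever sees operands with identity dim-to-level
// mappings.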
struct AffineDimCollector /* ... */ {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};
struct AffineExprAdmissibleVisitor /* ... */ {
  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
  // ...
  void visitAddExpr(AffineBinaryOpExpr expr) {
    // ...
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    // ...
  }
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }
  // ...
  bool admissible = true;
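// Admissibility here means "directly loopable by the sparsifier": mod,
// floordiv and ceildiv sub-expressions unconditionally mark a level expression
// as inadmissible, while the elided add/mul handlers consult the isOutput
// flag, output maps being the stricter case. For illustration, a BSR-style
// dim-to-level view such as (d0, d1) -> (d0 floordiv 2, d1 floordiv 2,
// d0 mod 2, d1 mod 2) is exactly the kind of compound expression this visitor
// flags.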
using InadmissInfo = std::pair<BitVector, BitVector>;

// ...
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  // ...
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    // ...
      collector.walkPostOrder(map.getResult(lvl));
    // ...
  }
  // ...
  ret.second = collector.dims;
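// The returned InadmissInfo pairs two bit vectors: the first records which
// results (levels) of the map contain an inadmissible expression, the second
// records which loop dimensions appear inside those expressions (collected by
// the AffineDimCollector above). As an illustration, for
// (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2) on an input, levels {0, 2} and
// dimension {d0} would be flagged.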
static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info,
                                       AffineMap idxMap,
                                       SmallVector<utils::IteratorType> &itTps) {
  // ...
  auto [inAdLvls, usedDims] = info;
  // ...
  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // ...
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    // ...
      usedInLvl.walkPostOrder(e);
    // ...
    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();
    // ...
    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      // ...
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    // ...
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // ...
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  // ...
  for (unsigned l : inAdLvls.set_bits()) {
    // ...
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // ...
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
    // ...
  }
  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // ...
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
      // ...
      transItTps.push_back(itTps[d]);
      // ...
  }
  // ...
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // ...
  itTps.assign(transItTps.begin(), transItTps.end());
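// What this builds, summarized from the fragments above: a substitution map
// that re-expresses the kernel in terms of levels. Each inadmissible level
// receives a fresh leading loop dimension (curRepID starts at 0), the
// surviving original dimensions are shifted behind them (curOriID starts at
// inAdLvls.count()), and the iterator-type vector is permuted into the same
// order via transItTps. The resulting dimension count is
//   numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count().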
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // ...
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    // ...
    if (stt && !stt->isIdentity()) {
      // ...
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // ...
    if (ShapedType::isStatic(lvlSz)) {
      // ...
      cstMapping.try_emplace(divExp, c0);
      // ...
      cstMapping.try_emplace(modExp, lvlExp);
      // ...
    }

  unsigned boundedNum = 0;
  // ...
  for (OpOperand &operand : op->getOpOperands()) {
    // ...
    if (!stt || !stt->getEncoding())
      // ...
    unsigned tid = operand.getOperandNumber();
    bool isOutput = &operand == op.getDpsInitOperand(0);
    // ...
    auto [inAdLvls, dimExprs] = inAdInfo;
    for (unsigned d : dimExprs.set_bits()) {
      // ...
    }
    // ...
    if (inAdLvls.count() != 0) {
      // ...
      unsigned position = 0;
      for (unsigned lvl : inAdLvls.set_bits()) {
        // ...
        populateCstMapping(cstMapping, position, lvlSz);
        // ...
      }
      // ...
      for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
        AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
        idxMapArray[tid] = transMap.replace(
            // ...
      }
      // ...
      boundedNum += inAdLvls.count();
    }
  }

  // ...
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });
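// translateMap, per the fragments above, first composes each sparse operand's
// dimToLvl map with its indexing map so the map yields level coordinates,
// then removes the remaining inadmissible expressions by introducing new loop
// dimensions and, when level sizes are static, recording simplifications for
// the corresponding floordiv/mod sub-expressions in cstMapping. The rewritten
// indexing maps and iterator types are packaged as a pair of ArrayAttrs; the
// std::optional return leaves room for bailing out when translation is not
// possible.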
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),
                                  val);
}
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);
}
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
  return ret;
}
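// These three helpers are the glue used throughout this file: genDemap casts
// a value to the same encoding minus its dimToLvl map (the "demapped" view),
// genRemap casts it back to the original encoding, and remapValueRange
// inserts reinterpret casts only where a value's type differs from the
// expected type.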
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // ...
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        // ...

    // ...
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // ...
    Value res = linalgOp.getResult(0);
    // ...
    auto [idxMap, itTp] = *transMap;
    // ...
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // ...
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    // ...
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
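// Net effect of this pattern, as far as the visible fragments go: a
// single-init, pure-tensor linalg.generic whose operands carry non-identity
// dimToLvl maps is rewritten in place to consume the demapped operands from
// the adaptor, its indexing maps and iterator types are swapped for the
// translated ones, the result type becomes the demapped output type, and,
// where the result is sparse, a final reinterpret_map (genRemap) restores the
// original type for its users.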
  GenericOpScheduler(MLIRContext *context,
                     sparse_tensor::LoopOrderingStrategy strategy)
      : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        // ...

    // ...
    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      // ...

    // ...
    bool isAdmissible = false;
    // ...
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput, SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      // ...
      if (isAdmissibleOrder(linalgOp, order)) {
        // ...
    }

    // ...
    if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled: loop detected.");
    // ...
    return rewriter.notifyMatchFailure(
        linalgOp, "the sparse kernel can not be scheduled.");

    // ...
    linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));

    // ...
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    // ...
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    // ...
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order);
    // ...
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
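// Scheduling strategy visible above: the "sorted" attribute guards against
// rescheduling an already-processed op; the sort masks are tried from the
// most constrained (kIncludeAll) down to kSparseOnly, and the first
// permutation accepted by isAdmissibleOrder is applied by composing every
// indexing map with the permutation and permuting the iterator types to
// match. If no admissible order exists, resolveCycle is given a chance to
// break the cycle before the match fails.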
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    // ...
    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    // ...
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      // ...
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      // ...
    }
    // ...
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }
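// Judging from the return expression, an order is admissible when the number
// of outermost non-reduction loops (nest, counted until the first reduction
// iterator in the elided loop body above) is at least rank(output) - 1, i.e.
// at most the innermost output dimension may sit under a reduction.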
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // ...
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      // ...
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        // ...

      // ...
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      // ...
      assert(stt.isIdentity());
      // ...
      idxMap = idxMap.compose(order);
      // ...
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      // ...
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      // ...
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      // ...
      assert(!dimToLvl.isIdentity());
      // ...
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
      // ...
      linalgOp->setOperand(t->getOperandNumber(), dst);
      // ...
      bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);
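// Cycle resolution, per the fragments above: pick a sparse input whose
// indexing map uses only plain dimension expressions, compute a sparse-only
// loop order for it, derive a permutation of its levels from that order
// (lvlSeq/perm), and materialize the input in that new dimToLvl layout with a
// sparse_tensor.convert. The converted temporary replaces the operand and is
// deallocated afterwards via bufferization.dealloc_tensor.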
template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    Location loc = op.getLoc();
    // ...
    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    // ...
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
                                             /*...*/);
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
                                         /*...*/);
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty());
    // ...
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    // ...
    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
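// Size translation used above: for every dimension the largest valid
// coordinate (size - 1) is formed, the coordinate vector is translated
// dim2lvl, and each dynamic level size is recovered from the translated
// maximum coordinate (the elided arith.addi operand). As an illustration
// only: a 10x8 tensor under (d0, d1) -> (d0 floordiv 2, d1 floordiv 2,
// d0 mod 2, d1 mod 2) has max coordinates (9, 7), which translate to
// (4, 3, 1, 1), giving level sizes (5, 4, 2, 2).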
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // ...
    Location loc = op.getLoc();
    // ...
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
                                             adaptor.getDest(), lvlCrd);
    // ...
    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  // ...
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    // ...
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // ...
    op.getTensorMutable().assign(adaptor.getTensor());
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // ...
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        // ...

    Location loc = op.getLoc();
    // ...
    SmallVector<Type> prevRetTps(op.getResultTypes());
    // ...
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // ...
    for (auto r : op.getResults())
      // ...
        r.setType(stt->getDemappedType());

    // ...
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // ...
    for (Type t : blockArgTps)
      // ...

    // ...
    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    // ...
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    // ...
    unsigned numInitArgs = op.getInitArgs().size();
    // ...
    SmallVector<Value> reMappedArgs =
        // ...
    if (numInitArgs != 0) {
      // ...
          stt && !stt->isIdentity()) {
        // ...
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        YieldOp::create(rewriter, loc, y);
      }
    }
    // ...
    SmallVector<Value> outs =
        // ...
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      // ...
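// For sparse_tensor.foreach the demapping also has to touch the body: the
// block now receives lvlRank coordinate arguments plus the element and init
// arguments, the level coordinates are translated back lvl2dim so existing
// uses of the old dimension coordinates keep working, and yielded sparse
// values are demapped before the terminator is recreated. The op's results
// are then reconciled with their original types (prevRetTps) via the zip
// above.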
  patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
               TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
               SparseDisassembleDemapper, TensorInsertDemapper,
               ForeachOpDemapper>(patterns.getContext());
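// A minimal usage sketch, not part of the pass itself: the pass wrapper and
// its name below are hypothetical and only illustrate how these patterns are
// typically driven through populateSparseReinterpretMap; the kAll scope value
// is assumed, and older MLIR releases spell the greedy driver
// applyPatternsAndFoldGreedily.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct DemoReinterpretMapPass
    : public mlir::PassWrapper<DemoReinterpretMapPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Register all reinterpret-map patterns, including the ones above.
    mlir::populateSparseReinterpretMap(patterns,
                                       mlir::ReinterpretMapScope::kAll);
    if (mlir::failed(mlir::applyPatternsGreedily(getOperation(),
                                                 std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace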