template <typename SubClass, typename SourceOp>
using OpAdaptor = typename SourceOp::Adaptor;
LogicalResult matchAndRewrite(SourceOp op,
for (Value &in : deMappedIns) {
  ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
OpAdaptor adaptor(deMappedIns, op);
LogicalResult status =
    static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
return changed ? success() : status;
explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};

struct AffineExprAdmissibleVisitor
explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};
operator bool() { return admissible; }
bool admissible = true;

using InadmissInfo = std::pair<BitVector, BitVector>;
AffineDimCollector collector(map.getNumDims());
for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
  AffineExprAdmissibleVisitor admissible(isOutput);
  admissible.walkPostOrder(map.getResult(lvl));
  collector.walkPostOrder(map.getResult(lvl));
ret.second = collector.dims;
auto [inAdLvls, usedDims] = info;
assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
  AffineDimCollector usedInLvl(idxMap.getNumDims());
  usedInLvl.walkPostOrder(e);
  unsigned curUsedDimID = 0;
  unsigned curUnusedDimID = lvl2Idx.getNumDims();
  BitVector unused = usedInLvl.dims.flip();
  for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
    results.push_back(lvl2Idx.getResult(curUsedDimID++));
  AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
unsigned curRepID = 0;
unsigned curOriID = inAdLvls.count();
for (unsigned l : inAdLvls.set_bits()) {
  AffineDimCollector collector(idxMap.getNumDims());
  collector.walkPostOrder(lvlExp);
  assert(collector.dims.count() == 1);
  transItTps.push_back(itTps[collector.dims.find_first()]);
for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
  if (usedDims.test(d)) {
    results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    transItTps.push_back(itTps[d]);
unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
itTps.assign(transItTps.begin(), transItTps.end());
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
  Value tensor = op->getOpOperand(i).get();
  if (stt && !stt->isIdentity()) {
    idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);

unsigned pos, int64_t lvlSz) {
if (ShapedType::isStatic(lvlSz)) {
  cstMapping.try_emplace(divExp, c0);
  cstMapping.try_emplace(modExp, lvlExp);

unsigned boundedNum = 0;
for (OpOperand &operand : op->getOpOperands()) {
  if (!stt || !stt->getEncoding())
  unsigned tid = operand.getOperandNumber();
  bool isOutput = &operand == op.getDpsInitOperand(0);
  auto [inAdLvls, dimExprs] = inAdInfo;
  for (unsigned d : dimExprs.set_bits()) {
  if (inAdLvls.count() != 0) {
    unsigned position = 0;
    for (unsigned lvl : inAdLvls.set_bits()) {
      int64_t lvlSz = lvlShape[lvl];
      populateCstMapping(cstMapping, position, lvlSz);
    for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
      AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
      idxMapArray[tid] = transMap.replace(
    boundedNum += inAdLvls.count();
llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),

return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);

assert(outs.size() == types.size());
for (auto [r, t] : llvm::zip(ret, types))
  if (r.getType() != t)
    r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        linalgOp, "the sparse kernel can not be sparsified.");
    Value res = linalgOp.getResult(0);
    auto [idxMap, itTp] = *transMap;
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    if (stt && stt->hasEncoding()) {
LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
  if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
  const StringRef sorted = "sorted";
  if (linalgOp->hasAttr(sorted))
  bool isAdmissible = false;
  for (const SortMask mask : allMasks) {
    order = scheduler.sort(mask);
    if (isAdmissibleOrder(linalgOp, order)) {
  if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
      linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      linalgOp, "the sparse kernel can not be scheduled.");
  linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
  ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
  curItTypes.reserve(preItTypes.size());
  unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
  curItTypes.push_back(preItTypes[loopID]);
  idxMap = idxMap.compose(order);
  linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
  unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
  cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
  return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
linalg::LinalgOp linalgOp,
for (OpOperand *t : linalgOp.getDpsInputOperands()) {
  Value tval = t->get();
  AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
  return !llvm::isa<AffineDimExpr>(exp);
  if (!srcEnc || hasCompExpr)
  assert(stt.isIdentity());
  idxMap = idxMap.compose(order);
  unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
  lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
  llvm::sort(lvlSeq, llvm::less_first());
  llvm::to_vector(llvm::make_second_range(lvlSeq));
  assert(!dimToLvl.isIdentity());
  RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
  Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
  linalgOp->setOperand(t->getOperandNumber(), dst);
  bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);
template <typename AllocOp>
LogicalResult matchAndRewrite(AllocOp op,
  maxDimCrds.reserve(stt.getDimRank());
  for (int64_t dimSz : stt.getDimShape()) {
    if (ShapedType::isDynamic(dimSz)) {
      Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
      maxDimCrds.push_back(maxCrd);
      dynSz = dynSz.drop_front();
      maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
  ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                            CrdTransDirectionKind::dim2lvl);
  auto lvlShape = stt.getLvlShape();
  for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
    if (ShapedType::isDynamic(lvlShape[i])) {
      Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
      dynLvlSzs.push_back(sz);
  assert(dynSz.empty());
  op->setOperands(dynLvlSzs);
  op.getResult().setType(stt.getDemappedType());
  Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
                                             adaptor.getDest(), lvlCrd);
    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
LogicalResult matchAndRewrite(AssembleOp op,
  op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
  Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
    op.getTensorMutable().assign(adaptor.getTensor());
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    for (auto r : op.getResults())
      r.setType(stt->getDemappedType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    for (Type t : blockArgTps)
    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    unsigned numInitArgs = op.getInitArgs().size();
    if (numInitArgs != 0) {
      stt && !stt->isIdentity()) {
      genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
      YieldOp::create(rewriter, loc, y);
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
             TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
             SparseDisassembleDemapper, TensorInsertDemapper,
             ForeachOpDemapper>(patterns.getContext());
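The fragments above end with the pattern registration. As a minimal, hedged sketch (not taken from this file), the registered patterns could be driven from a pass roughly as follows, assuming the ReinterpretMapScope::kAll enumerator and the applyPatternsGreedily driver are available under these names in the MLIR build at hand:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>

static void runReinterpretMap(mlir::Operation *root) {
  using namespace mlir;
  RewritePatternSet patterns(root->getContext());
  // Registers GenericOpReinterpretMap/GenericOpScheduler plus the demappers.
  populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
  // Apply the patterns greedily until a fixed point is reached.
  (void)applyPatternsGreedily(root, std::move(patterns));
}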
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types, ValueRange outs)
static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, SmallVector<utils::IteratorType> &itTps)
static std::optional<std::pair<ArrayAttr, ArrayAttr>> translateMap(linalg::GenericOp op, PatternRewriter &rewriter)
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput)
Affine binary operation expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef<AffineExpr> getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
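As a small illustrative sketch of how the AffineMap calls above are used by the snippets (compose, isPermutation, inversePermutation); the helper names are invented for the example:

#include "mlir/IR/AffineMap.h"
#include <cassert>

static mlir::AffineMap remapIndexing(mlir::AffineMap dim2Lvl,
                                     mlir::AffineMap idxMap) {
  // dim2Lvl.compose(idxMap) builds the map d -> dim2Lvl(idxMap(d)), which is
  // how the snippets rewrite a tensor's dim-space indexing map into lvl space.
  return dim2Lvl.compose(idxMap);
}

static mlir::AffineMap invertLoopOrder(mlir::AffineMap order) {
  // inversePermutation() is only meaningful for symbol-less permutation maps,
  // hence the isPermutation() guard.
  assert(order.isPermutation() && "expected a permutation map");
  return mlir::inversePermutation(order);
}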
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
BoolAttr getBoolAttr(bool value)
ArrayAttr getArrayAttr(ArrayRef<Attribute> value)
ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
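A hedged sketch of the Block and OpBuilder calls listed above, mirroring how the ForeachOp demapper extends and repopulates its body block; the helper and its parameters are placeholders, not code from this file:

#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeRange.h"

static void appendArgsAndSetInsertion(mlir::OpBuilder &builder,
                                      mlir::Block *body,
                                      mlir::TypeRange newArgTps,
                                      mlir::Location loc) {
  // Append one block argument per requested type.
  for (mlir::Type t : newArgTps)
    body->addArgument(t, loc);
  // Ops created through the builder now go to the start of the block.
  builder.setInsertionPointToStart(body);
}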
This class represents an operand of an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
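A minimal, self-contained pattern (not from this file) showing the rewriter hooks above in context: notifyMatchFailure for a reported bail-out and modifyOpInPlace to wrap an in-place attribute update; the "visited" attribute name is invented for the example:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

struct MarkVisited : public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
  using mlir::OpRewritePattern<mlir::linalg::GenericOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::GenericOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Bail out with a diagnostic the driver can report.
    if (op->hasAttr("visited"))
      return rewriter.notifyMatchFailure(op, "already visited");
    // Wrap the in-place attribute update so listeners are notified.
    rewriter.modifyOpInPlace(
        op, [&] { op->setAttr("visited", rewriter.getUnitAttr()); });
    return mlir::success();
  }
};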
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp)
Factory method that construct an iteration graph sorter for the given linalg.generic operation.
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
Level getLvlRank() const
Returns the level-rank.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff the MLIR operation has any sparse operand or result.
Returns true iff MLIR operand has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
std::optional<SparseTensorType> tryGetSparseTensorType(Value val)
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SortMask
Iteration graph sorting mask.
bool hasAnySparseResult(Operation *op)
Returns true iff the MLIR operation has any sparse result.
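A hedged sketch of the sparse_tensor helpers above, showing the check the demapping patterns rely on; the helper name is invented for the example:

#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include <optional>

static bool needsDemapping(mlir::Value v) {
  using namespace mlir::sparse_tensor;
  // tryGetSparseTensorType returns std::nullopt for values that are not
  // ranked tensors; a present but non-identity dimToLvl map is what the
  // demapping patterns above key on.
  if (std::optional<SparseTensorType> stt = tryGetSparseTensorType(v))
    return stt->hasEncoding() && !stt->isIdentity();
  return false;
}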
Include the generated interface declarations.
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
ReinterpretMapScope
Defines a scope for reinterpret map pass.
const FrozenRewritePatternSet &patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
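To make the populateCstMapping fragment concrete, here is an illustrative reconstruction using the free functions above; the helper name and the (void) sinks are assumptions for the example, but the expressions correspond to the divExp/modExp entries placed in cstMapping:

#include "mlir/IR/AffineExpr.h"
#include <cstdint>

static void buildFoldableExprs(mlir::MLIRContext *ctx, unsigned pos,
                               int64_t lvlSz) {
  using namespace mlir;
  AffineExpr lvlExp = getAffineDimExpr(pos, ctx);
  AffineExpr szExp = getAffineConstantExpr(lvlSz, ctx);
  // d_pos floordiv lvlSz simplifies to 0 once d_pos is known to be < lvlSz;
  // the snippet maps divExp to the constant 0 in cstMapping.
  AffineExpr divExp =
      getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
  // d_pos mod lvlSz simplifies back to d_pos under the same bound;
  // the snippet maps modExp back to lvlExp in cstMapping.
  AffineExpr modExp =
      getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
  (void)divExp;
  (void)modExp;
}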
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef<StringRef> generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...