33 template <
typename SubClass,
typename SourceOp>
36 using OpAdaptor =
typename SourceOp::Adaptor;
38 LogicalResult matchAndRewrite(SourceOp op,
45 for (
Value &in : deMappedIns) {
47 in = rewriter.
create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
53 OpAdaptor adaptor(deMappedIns, op);
54 LogicalResult status =
55 static_cast<const SubClass *
>(
this)->rewriteOp(op, adaptor, rewriter);
56 return changed ? success() : status;
62 explicit AffineDimCollector(
unsigned dimNum) : dims(dimNum){};
68 struct AffineExprAdmissibleVisitor
70 explicit AffineExprAdmissibleVisitor(
bool isOutput)
71 : admissible(true), isOutput(isOutput){};
87 operator bool() {
return admissible; }
// Pair of bit vectors returned by collectInadmissInfo. Callers in this file
// unpack it as [inAdLvls, usedDims]: presumably `first` flags the
// inadmissible levels of an indexing map, while `second` is filled from the
// dims gathered by an AffineDimCollector.
97 using InadmissInfo = std::pair<BitVector, BitVector>;
109 AffineDimCollector collector(map.
getNumDims());
110 for (
unsigned lvl = 0, e = map.
getNumResults(); lvl < e; lvl++) {
111 AffineExprAdmissibleVisitor admissible(isOutput);
112 admissible.walkPostOrder(map.
getResult(lvl));
117 collector.walkPostOrder(map.
getResult(lvl));
120 ret.second = collector.dims;
155 auto [inAdLvls, usedDims] = info;
163 assert(lvl2Idx.getNumResults() <= idxMap.
getNumDims());
164 if (lvl2Idx.getNumResults() != idxMap.
getNumDims()) {
171 AffineDimCollector usedInLvl(idxMap.
getNumDims());
173 usedInLvl.walkPostOrder(e);
175 unsigned curUsedDimID = 0;
176 unsigned curUnusedDimID = lvl2Idx.getNumDims();
178 BitVector unused = usedInLvl.dims.flip();
179 for (
unsigned i = 0; i < idxMap.
getNumDims(); i++) {
183 results.push_back(lvl2Idx.getResult(curUsedDimID++));
186 AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
188 assert(lvl2Idx.getNumResults() == idxMap.
getNumDims());
195 unsigned curRepID = 0;
196 unsigned curOriID = inAdLvls.count();
201 for (
unsigned l : inAdLvls.set_bits()) {
211 AffineDimCollector collector(idxMap.
getNumDims());
212 collector.walkPostOrder(lvlExp);
214 assert(collector.dims.count() == 1);
215 transItTps.push_back(itTps[collector.dims.find_first()]);
218 for (
unsigned d = 0, e = idxMap.
getNumDims(); d < e; d++) {
219 if (usedDims.test(d)) {
223 results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
229 transItTps.push_back(itTps[d]);
232 unsigned numDim = idxMap.
getNumDims() - usedDims.count() + inAdLvls.count();
234 itTps.assign(transItTps.begin(), transItTps.end());
242 static std::optional<std::pair<ArrayAttr, ArrayAttr>>
248 for (
unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
251 if (stt && !stt->isIdentity()) {
254 idxMapArray[i] = dim2Lvl.
compose(idxMapArray[i]);
261 unsigned pos, int64_t lvlSz) {
262 if (!ShapedType::isDynamic(lvlSz)) {
270 cstMapping.try_emplace(divExp, c0);
274 cstMapping.try_emplace(modExp, lvlExp);
278 unsigned boundedNum = 0;
286 if (!stt || !stt->getEncoding())
289 unsigned tid = operand.getOperandNumber();
290 bool isOutput = &operand == op.getDpsInitOperand(0);
293 auto [inAdLvls, dimExprs] = inAdInfo;
294 for (
unsigned d : dimExprs.set_bits()) {
302 if (inAdLvls.count() != 0) {
307 unsigned position = 0;
308 for (
unsigned lvl : inAdLvls.set_bits()) {
309 int64_t lvlSz = lvlShape[lvl];
310 populateCstMapping(cstMapping, position, lvlSz);
317 for (
unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
318 AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
319 idxMapArray[tid] = transMap.
replace(
324 boundedNum += inAdLvls.count();
330 llvm::map_to_vector(itTps, [ctx](
auto itTp) ->
Attribute {
341 return builder.
create<ReinterpretMapOp>(val.
getLoc(), enc.withoutDimToLvl(),
348 return builder.
create<ReinterpretMapOp>(val.
getLoc(), enc, val);
354 assert(outs.size() == types.size());
355 for (
auto [r, t] : llvm::zip(ret, types))
356 if (r.getType() != t)
357 r = rewriter.
create<ReinterpretMapOp>(r.getLoc(), t, r);
368 struct GenericOpReinterpretMap
369 :
public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
371 using DemapInsRewriter::DemapInsRewriter;
372 LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
376 if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
385 linalgOp,
"the sparse kernel can not be sparsified.");
388 Value res = linalgOp.getResult(0);
390 auto [idxMap, itTp] = *transMap;
393 linalgOp.setIndexingMapsAttr(idxMap);
394 linalgOp.setIteratorTypesAttr(itTp);
396 linalgOp.getInputsMutable().assign(adaptor.getInputs());
397 linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
398 res.
setType(adaptor.getOutputs()[0].getType());
402 if (stt && stt->hasEncoding()) {
412 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
414 if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
420 const StringRef sorted =
"sorted";
421 if (linalgOp->hasAttr(sorted))
425 bool isAdmissible =
false;
435 for (
const SortMask mask : allMasks) {
436 order = scheduler.sort(mask);
438 if (isAdmissibleOrder(linalgOp, order)) {
448 if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
450 linalgOp,
"the sparse kernel can not be scheduled: loop detected.");
457 linalgOp,
"the sparse kernel can not be scheduled.");
462 linalgOp->setAttr(sorted, rewriter.
getBoolAttr(
true));
471 ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
473 curItTypes.reserve(preItTypes.size());
475 unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
476 curItTypes.push_back(preItTypes[loopID]);
483 idxMap = idxMap.compose(order);
487 linalgOp.setIteratorTypesAttr(rewriter.
getArrayAttr(curItTypes));
495 static bool isAdmissibleOrder(linalg::GenericOp linalgOp,
AffineMap order) {
499 OpOperand *lhs = linalgOp.getDpsInitOperand(0);
501 const auto iteratorTypes = linalgOp.getIteratorTypesArray();
503 unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
505 cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
513 return static_cast<int64_t
>(nest) >= linalgOp.getRank(lhs) - 1;
518 linalg::LinalgOp linalgOp,
522 for (
OpOperand *t : linalgOp.getDpsInputOperands()) {
523 Value tval = t->get();
527 AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
529 return !llvm::isa<AffineDimExpr>(exp);
531 if (!srcEnc || hasCompExpr)
543 assert(stt.isIdentity());
546 idxMap = idxMap.
compose(order);
557 unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
558 lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
560 llvm::sort(lvlSeq, llvm::less_first());
562 llvm::to_vector(llvm::make_second_range(lvlSeq));
565 assert(!dimToLvl.isIdentity());
569 RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
572 linalgOp->setOperand(t->getOperandNumber(), dst);
578 rewriter.
create<bufferization::DeallocTensorOp>(dst.
getLoc(), dst);
592 template <
typename AllocOp>
595 LogicalResult matchAndRewrite(AllocOp op,
604 maxDimCrds.reserve(stt.getDimRank());
606 for (int64_t dimSz : stt.getDimShape()) {
607 if (ShapedType::isDynamic(dimSz)) {
610 maxDimCrds.push_back(maxCrd);
611 dynSz = dynSz.drop_front();
613 maxDimCrds.push_back(
constantIndex(rewriter, loc, dimSz - 1));
617 ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
618 CrdTransDirectionKind::dim2lvl);
619 auto lvlShape = stt.getLvlShape();
621 for (
unsigned i = 0, e = lvlShape.size(); i < e; i++) {
622 if (ShapedType::isDynamic(lvlShape[i])) {
625 dynLvlSzs.push_back(sz);
629 assert(dynSz.empty());
642 struct TensorInsertDemapper
643 :
public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
644 using DemapInsRewriter::DemapInsRewriter;
645 LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
652 ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
653 CrdTransDirectionKind::dim2lvl);
654 auto insertOp = rewriter.
create<tensor::InsertOp>(
655 loc, op.getScalar(), adaptor.getDest(), lvlCrd);
657 Value out =
genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
665 LogicalResult matchAndRewrite(AssembleOp op,
673 op, [&op, &stt]() { op.
getResult().setType(stt.getDemappedType()); });
681 struct SparseDisassembleDemapper
682 :
public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
683 using DemapInsRewriter::DemapInsRewriter;
684 LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
691 op.getTensorMutable().assign(adaptor.getTensor());
697 struct ForeachOpDemapper
698 :
public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
699 using DemapInsRewriter::DemapInsRewriter;
700 LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
708 if (
auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
709 if (
auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
718 op.getTensorMutable().assign(adaptor.getTensor());
719 op.getInitArgsMutable().assign(adaptor.getInitArgs());
723 r.setType(stt->getDemappedType());
728 blockArgTps.push_back(srcStt.getElementType());
729 blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
730 adaptor.getInitArgs().getTypes().end());
731 Block *body = op.getBody();
734 for (
Type t : blockArgTps)
741 ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
742 CrdTransDirectionKind::lvl2dim);
744 body->
getArguments().take_front(srcStt.getDimRank()), dimCrds);
747 unsigned numInitArgs = op.getInitArgs().size();
762 if (numInitArgs != 0) {
766 stt && !stt->isIdentity()) {
768 genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
769 rewriter.
create<YieldOp>(loc, y);
781 for (
auto [from, to] : llvm::zip(op.
getResults(), outs))
794 patterns.
add<GenericOpReinterpretMap, GenericOpScheduler>(
799 patterns.
add<TensorAllocDemapper<bufferization::AllocTensorOp>,
800 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
801 SparseDisassembleDemapper, TensorInsertDemapper,
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static SmallVector< Value > remapValueRange(OpBuilder &rewriter, TypeRange types, ValueRange outs)
static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, SmallVector< utils::IteratorType > &itTps)
static std::optional< std::pair< ArrayAttr, ArrayAttr > > translateMap(linalg::GenericOp op, PatternRewriter &rewriter)
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput)
Affine binary operation expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
BoolAttr getBoolAttr(bool value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
OpOperand & getOpOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp)
Factory method that constructs an iteration graph sorter for the given linalg.generic operation.
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
Level getLvlRank() const
Returns the level-rank.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff the MLIR operation has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SortMask
Iteration graph sorting mask.
bool hasAnySparseResult(Operation *op)
Returns true iff the MLIR operation has any sparse result.
Include the generated interface declarations.
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
ReinterpretMapScope
Defines a scope for reinterpret map pass.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...