29 : iterTypes(itTypes) {}
33 if (pickedDim ==
nullptr || pickIterType == iterTypes[expr.
getPosition()])
38 void setPickedIterType(utils::IteratorType iterType) {
39 pickIterType = iterType;
44 return llvm::cast<AffineDimExpr>(pickedDim);
57 utils::IteratorType pickIterType;
65 void visitDimExpr(
AffineDimExpr expr) { dims.push_back(expr); }
72 return static_cast<unsigned>(mask1) &
static_cast<unsigned>(mask2);
83 AffineMap IterationGraphSorter::topoSort() {
86 std::vector<unsigned> redIt;
87 std::vector<unsigned> parIt;
89 for (
unsigned i = 0; i < numLoops; i++) {
90 if (inDegree[i] == 0) {
91 if (iterTypes[i] == utils::IteratorType::reduction)
99 while (!redIt.empty() || !parIt.empty()) {
102 auto &it = !parIt.empty() ? parIt : redIt;
112 loopOrder.push_back(src);
115 for (
unsigned dst = 0; dst < numLoops; dst++) {
116 if (itGraph[src][dst] && --inDegree[dst] == 0) {
117 if (iterTypes[dst] == utils::IteratorType::reduction)
118 redIt.push_back(dst);
120 parIt.push_back(dst);
126 if (loopOrder.size() == numLoops)
138 genericOp.getNumDpsInits() == 1);
146 Value out = genericOp.getDpsInitOperand(0)->get();
148 genericOp.getIteratorTypesArray();
151 std::move(iterTypes), strategy);
154 IterationGraphSorter::IterationGraphSorter(
159 : ins(std::move(insArg)), loop2InsLvl(std::move(loop2InsLvlArg)), out(out),
160 loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypesArg)),
163 assert(loop2InsLvl.size() == ins.size());
165 assert(llvm::all_equal(llvm::map_range(
168 assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](
auto mvPair) {
169 auto [m, v] = mvPair;
173 if (
auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
174 return !shapedType.hasRank() ||
175 (m.getNumResults() == shapedType.getRank());
187 for (
auto &row : itGraph)
188 llvm::fill(row,
false);
191 llvm::fill(inDegree, 0);
194 for (
auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
200 addConstraints(in, map);
206 addConstraints(out, loop2OutLvl);
212 void IterationGraphSorter::addConstraints(
Value t,
AffineMap loop2LvlMap) {
213 auto addIterOrdering = [
this](
unsigned f,
unsigned t) {
214 if (!itGraph[f][t] && f != t) {
215 itGraph[f][t] =
true;
221 AffineDimFinder finder(iterTypes);
222 finder.setPickedIterType(utils::IteratorType::reduction);
228 for (
Level lvl = 1; lvl < lvlRank; lvl++) {
232 if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
235 AffineDimCollector fCollector;
236 fCollector.walkPostOrder(fa);
237 AffineDimCollector tCollector;
238 tCollector.walkPostOrder(ta);
240 for (
auto fd : fCollector.dims) {
241 for (
auto td : tCollector.dims) {
242 const unsigned f = fd.getPosition();
243 const unsigned t = td.getPosition();
244 addIterOrdering(f, t);
252 finder.walkPostOrder(fa);
256 finder.walkPostOrder(ta);
261 addIterOrdering(fldx, tldx);
263 AffineDimCollector fCollector;
264 fCollector.walkPostOrder(fa);
265 AffineDimCollector tCollector;
266 tCollector.walkPostOrder(ta);
269 for (
auto fd : fCollector.dims) {
270 const unsigned f = fd.getPosition();
271 addIterOrdering(f, fldx);
273 for (
auto td : tCollector.dims) {
274 const unsigned t = td.getPosition();
275 addIterOrdering(t, tldx);
280 for (
auto fd : fCollector.dims) {
281 const unsigned f = fd.getPosition();
284 for (
auto td : tCollector.dims) {
285 const unsigned t = td.getPosition();
288 addIterOrdering(f, t);
static bool includesDenseOutput(SortMask mask)
static bool includesAny(SortMask mask1, SortMask mask2)
static bool includesDenseInput(SortMask mask)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy)
Factory method that constructs an iteration graph sorter for the given linalg.generic operation with ...
unsigned getNumLoops() const
Returns the number of loops in the iteration graph.
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff MLIR operand has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
LoopOrderingStrategy
Defines a strategy for loop ordering during sparse code generation.
@ kDefault
Default strategy (eagerly selects last loop in topological sort).
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SortMask
Iteration graph sorting mask,.
Include the generated interface declarations.