29 : iterTypes(itTypes) {}
33 if (pickedDim ==
nullptr || pickIterType == iterTypes[expr.
getPosition()])
38 void setPickedIterType(utils::IteratorType iterType) {
39 pickIterType = iterType;
44 return llvm::cast<AffineDimExpr>(pickedDim);
57 utils::IteratorType pickIterType;
65 void visitDimExpr(
AffineDimExpr expr) { dims.push_back(expr); }
72 return static_cast<unsigned>(mask1) &
static_cast<unsigned>(mask2);
83 AffineMap IterationGraphSorter::topoSort() {
86 std::vector<unsigned> redIt;
87 std::vector<unsigned> parIt;
89 for (
unsigned i = 0; i < numLoops; i++) {
90 if (inDegree[i] == 0) {
91 if (iterTypes[i] == utils::IteratorType::reduction)
99 while (!redIt.empty() || !parIt.empty()) {
102 auto &it = !parIt.empty() ? parIt : redIt;
112 loopOrder.push_back(src);
115 for (
unsigned dst = 0; dst < numLoops; dst++) {
116 if (itGraph[src][dst] && --inDegree[dst] == 0) {
117 if (iterTypes[dst] == utils::IteratorType::reduction)
118 redIt.push_back(dst);
120 parIt.push_back(dst);
126 if (loopOrder.size() == numLoops)
138 genericOp.getNumDpsInits() == 1);
146 Value out = genericOp.getDpsInitOperand(0)->get();
148 genericOp.getIteratorTypesArray();
151 std::move(iterTypes), strategy);
154 IterationGraphSorter::IterationGraphSorter(
158 : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
159 loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
162 assert(loop2InsLvl.size() == ins.size());
164 assert(llvm::all_equal(llvm::map_range(
167 assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](
auto mvPair) {
168 auto [m, v] = mvPair;
169 return m.
getNumResults() == cast<ShapedType>(v.getType()).getRank();
178 for (
auto &row : itGraph)
179 llvm::fill(row,
false);
182 llvm::fill(inDegree, 0);
185 for (
auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
191 addConstraints(in, map);
197 addConstraints(out, loop2OutLvl);
203 void IterationGraphSorter::addConstraints(
Value t,
AffineMap loop2LvlMap) {
204 auto addIterOrdering = [
this](
unsigned f,
unsigned t) {
205 if (!itGraph[f][t] && f != t) {
206 itGraph[f][t] =
true;
212 AffineDimFinder finder(iterTypes);
213 finder.setPickedIterType(utils::IteratorType::reduction);
219 for (
Level lvl = 1; lvl < lvlRank; lvl++) {
223 if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
226 AffineDimCollector fCollector;
227 fCollector.walkPostOrder(fa);
228 AffineDimCollector tCollector;
229 tCollector.walkPostOrder(ta);
231 for (
auto fd : fCollector.dims) {
232 for (
auto td : tCollector.dims) {
233 const unsigned f = fd.getPosition();
234 const unsigned t = td.getPosition();
235 addIterOrdering(f, t);
243 finder.walkPostOrder(fa);
247 finder.walkPostOrder(ta);
252 addIterOrdering(fldx, tldx);
254 AffineDimCollector fCollector;
255 fCollector.walkPostOrder(fa);
256 AffineDimCollector tCollector;
257 tCollector.walkPostOrder(ta);
260 for (
auto fd : fCollector.dims) {
261 const unsigned f = fd.getPosition();
262 addIterOrdering(f, fldx);
264 for (
auto td : tCollector.dims) {
265 const unsigned t = td.getPosition();
266 addIterOrdering(t, tldx);
271 for (
auto fd : fCollector.dims) {
272 const unsigned f = fd.getPosition();
275 for (
auto td : tCollector.dims) {
276 const unsigned t = td.getPosition();
279 addIterOrdering(f, t);
static bool includesDenseOutput(SortMask mask)
static bool includesAny(SortMask mask1, SortMask mask2)
static bool includesDenseInput(SortMask mask)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy)
Factory method that constructs an iteration graph sorter for the given linalg.generic operation with ...
unsigned getNumLoops() const
Returns the number of loops in the iteration graph.
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff MLIR operand has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
LoopOrderingStrategy
Defines a strategy for loop ordering during sparse code generation.
@ kDefault
Default strategy (eagerly selects last loop in topological sort).
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SortMask
Iteration graph sorting mask,.
Include the generated interface declarations.