30 : iterTypes(itTypes) {}
34 if (pickedDim ==
nullptr || pickIterType == iterTypes[expr.
getPosition()])
39 void setPickedIterType(utils::IteratorType iterType) {
40 pickIterType = iterType;
45 return llvm::cast<AffineDimExpr>(pickedDim);
58 utils::IteratorType pickIterType;
66 void visitDimExpr(
AffineDimExpr expr) { dims.push_back(expr); }
73 return static_cast<unsigned>(mask1) &
static_cast<unsigned>(mask2);
84 AffineMap IterationGraphSorter::topoSort() {
87 std::vector<unsigned> redIt;
88 std::vector<unsigned> parIt;
90 for (
unsigned i = 0; i < numLoops; i++) {
91 if (inDegree[i] == 0) {
92 if (iterTypes[i] == utils::IteratorType::reduction)
100 while (!redIt.empty() || !parIt.empty()) {
103 auto &it = !parIt.empty() ? parIt : redIt;
104 auto src = it.back();
105 loopOrder.push_back(src);
108 for (
unsigned dst = 0; dst < numLoops; dst++) {
109 if (itGraph[src][dst] && --inDegree[dst] == 0) {
110 if (iterTypes[dst] == utils::IteratorType::reduction)
111 redIt.push_back(dst);
113 parIt.push_back(dst);
119 if (loopOrder.size() == numLoops)
131 genericOp.getNumDpsInits() == 1);
139 Value out = genericOp.getDpsInitOperand(0)->get();
141 genericOp.getIteratorTypesArray();
144 std::move(iterTypes));
147 IterationGraphSorter::IterationGraphSorter(
150 : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
151 loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
153 assert(loop2InsLvl.size() == ins.size());
155 assert(llvm::all_equal(llvm::map_range(
156 loop2InsLvl, [](
AffineMap m) {
return m.getNumDims(); })));
158 assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](
auto mvPair) {
159 auto [m, v] = mvPair;
160 return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
169 for (
auto &row : itGraph)
170 std::fill(row.begin(), row.end(),
false);
173 std::fill(inDegree.begin(), inDegree.end(), 0);
176 for (
auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
182 addConstraints(in, map);
188 addConstraints(out, loop2OutLvl);
194 void IterationGraphSorter::addConstraints(
Value t,
AffineMap loop2LvlMap) {
195 auto addIterOrdering = [
this](
unsigned f,
unsigned t) {
196 if (!itGraph[f][t] && f != t) {
197 itGraph[f][t] =
true;
203 AffineDimFinder finder(iterTypes);
204 finder.setPickedIterType(utils::IteratorType::reduction);
210 for (
Level lvl = 1; lvl < lvlRank; lvl++) {
214 if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
217 AffineDimCollector fCollector;
218 fCollector.walkPostOrder(fa);
219 AffineDimCollector tCollector;
220 tCollector.walkPostOrder(ta);
222 for (
auto fd : fCollector.dims) {
223 for (
auto td : tCollector.dims) {
224 const unsigned f = fd.getPosition();
225 const unsigned t = td.getPosition();
226 addIterOrdering(f, t);
234 finder.walkPostOrder(fa);
238 finder.walkPostOrder(ta);
243 addIterOrdering(fldx, tldx);
245 AffineDimCollector fCollector;
246 fCollector.walkPostOrder(fa);
247 AffineDimCollector tCollector;
248 tCollector.walkPostOrder(ta);
251 for (
auto fd : fCollector.dims) {
252 const unsigned f = fd.getPosition();
253 addIterOrdering(f, fldx);
255 for (
auto td : tCollector.dims) {
256 const unsigned t = td.getPosition();
257 addIterOrdering(t, tldx);
262 for (
auto fd : fCollector.dims) {
263 const unsigned f = fd.getPosition();
266 for (
auto td : tCollector.dims) {
267 const unsigned t = td.getPosition();
270 addIterOrdering(f, t);
static bool includesDenseOutput(SortMask mask)
static bool includesAny(SortMask mask1, SortMask mask2)
static bool includesDenseInput(SortMask mask)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
unsigned getNumLoops() const
Returns the number of loops in the iteration graph.
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp)
Factory method that construct an iteration graph sorter for the given linalg.generic operation.
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff MLIR operand has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SortMask
Iteration graph sorting mask,.
Include the generated interface declarations.