29 : iterTypes(itTypes) {}
33 if (pickedDim ==
nullptr || pickIterType == iterTypes[expr.
getPosition()])
38 void setPickedIterType(utils::IteratorType iterType) {
39 pickIterType = iterType;
44 return llvm::cast<AffineDimExpr>(pickedDim);
57 utils::IteratorType pickIterType;
65 void visitDimExpr(
AffineDimExpr expr) { dims.push_back(expr); }
72 return static_cast<unsigned>(mask1) &
static_cast<unsigned>(mask2);
83 AffineMap IterationGraphSorter::topoSort() {
86 std::vector<unsigned> redIt;
87 std::vector<unsigned> parIt;
89 for (
unsigned i = 0; i < numLoops; i++) {
90 if (inDegree[i] == 0) {
91 if (iterTypes[i] == utils::IteratorType::reduction)
99 while (!redIt.empty() || !parIt.empty()) {
102 auto &it = !parIt.empty() ? parIt : redIt;
103 auto src = it.back();
104 loopOrder.push_back(src);
107 for (
unsigned dst = 0; dst < numLoops; dst++) {
108 if (itGraph[src][dst] && --inDegree[dst] == 0) {
109 if (iterTypes[dst] == utils::IteratorType::reduction)
110 redIt.push_back(dst);
112 parIt.push_back(dst);
118 if (loopOrder.size() == numLoops)
130 genericOp.getNumDpsInits() == 1);
138 Value out = genericOp.getDpsInitOperand(0)->get();
140 genericOp.getIteratorTypesArray();
143 std::move(iterTypes));
146 IterationGraphSorter::IterationGraphSorter(
149 : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
150 loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
152 assert(loop2InsLvl.size() == ins.size());
154 assert(llvm::all_equal(llvm::map_range(
157 assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](
auto mvPair) {
158 auto [m, v] = mvPair;
159 return m.
getNumResults() == cast<ShapedType>(v.getType()).getRank();
168 for (
auto &row : itGraph)
169 llvm::fill(row,
false);
172 llvm::fill(inDegree, 0);
175 for (
auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
181 addConstraints(in, map);
187 addConstraints(out, loop2OutLvl);
193 void IterationGraphSorter::addConstraints(
Value t,
AffineMap loop2LvlMap) {
194 auto addIterOrdering = [
this](
unsigned f,
unsigned t) {
195 if (!itGraph[f][t] && f != t) {
196 itGraph[f][t] =
true;
202 AffineDimFinder finder(iterTypes);
203 finder.setPickedIterType(utils::IteratorType::reduction);
209 for (
Level lvl = 1; lvl < lvlRank; lvl++) {
213 if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
216 AffineDimCollector fCollector;
217 fCollector.walkPostOrder(fa);
218 AffineDimCollector tCollector;
219 tCollector.walkPostOrder(ta);
221 for (
auto fd : fCollector.dims) {
222 for (
auto td : tCollector.dims) {
223 const unsigned f = fd.getPosition();
224 const unsigned t = td.getPosition();
225 addIterOrdering(f, t);
233 finder.walkPostOrder(fa);
237 finder.walkPostOrder(ta);
242 addIterOrdering(fldx, tldx);
244 AffineDimCollector fCollector;
245 fCollector.walkPostOrder(fa);
246 AffineDimCollector tCollector;
247 tCollector.walkPostOrder(ta);
250 for (
auto fd : fCollector.dims) {
251 const unsigned f = fd.getPosition();
252 addIterOrdering(f, fldx);
254 for (
auto td : tCollector.dims) {
255 const unsigned t = td.getPosition();
256 addIterOrdering(t, tldx);
261 for (
auto fd : fCollector.dims) {
262 const unsigned f = fd.getPosition();
265 for (
auto td : tCollector.dims) {
266 const unsigned t = td.getPosition();
269 addIterOrdering(f, t);
static bool includesDenseOutput(SortMask mask)
static bool includesAny(SortMask mask1, SortMask mask2)
static bool includesDenseInput(SortMask mask)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
unsigned getNumLoops() const
Returns the number of loops in the iteration graph.
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp)
Factory method that construct an iteration graph sorter for the given linalg.generic operation.
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff MLIR operand has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SortMask
Iteration graph sorting mask,.
Include the generated interface declarations.