MLIR  16.0.0git
Sparsification.cpp
1 //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements converting sparse tensor types to actual sparse code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodegenUtils.h"
14 
31 #include "mlir/IR/Matchers.h"
32 #include "mlir/IR/TensorEncoding.h"
33 #include "llvm/ADT/SmallBitVector.h"
34 
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
37 
38 //===----------------------------------------------------------------------===//
39 // Declarations of data structures.
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
43 
44 // Iteration graph sorting.
45 enum SortMask {
46  kSparseOnly = 0x0,
47  kIncludeDense = 0x1,
48  kIncludeUndef = 0x2,
49  kIncludeAll = 0x3
50 };
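// Illustrative note (not part of the original file): these are bitmask
// values, so kIncludeAll == (kIncludeDense | kIncludeUndef), and callers can
// test one constraint class with a bitwise AND, e.g.
//   if (mask & SortMask::kIncludeDense) { /* also add dense constraints */ }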
51 
52 // Reduction kinds.
53 enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
54 
55 // Code generation.
56 struct CodeGen {
57  CodeGen(SparsificationOptions o, MLIRContext *context, ValueRange tensors,
58  unsigned numTensors, unsigned numLoops, OpOperand *op, unsigned nest,
59  std::vector<unsigned> &ts)
60  : options(o),
61  loopEmitter(
62  tensors,
63  StringAttr::get(context, linalg::GenericOp::getOperationName()),
64  /*hasOutput=*/true,
65  /*isSparseOut=*/op != nullptr, ts),
66  sparseOut(op), outerParNest(nest), topSort(ts) {
67  if (op)
68  insChain = op->get();
69  }
70  /// Sparsification options.
71  SparsificationOptions options;
72  /// Loop emitter helper class.
73  SparseTensorLoopEmitter loopEmitter;
74  /// Current reduction, updated during code generation. When indices of a
75  /// reduction are exhausted, all inner loops can use a scalarized reduction.
76  unsigned redExp = -1u;
77  Value redVal;
78  Reduction redKind = kNoReduc;
79  unsigned redCustom = -1u;
80  // Sparse tensor as output. Implemented either through direct injective
81  // insertion in lexicographic index order or through access pattern expansion
82  // in the innermost loop nest (`expValues` through `expCount`).
83  OpOperand *sparseOut;
84  unsigned outerParNest;
85  Value insChain; // bookkeeping for insertion chain
86  Value expValues;
87  Value expFilled;
88  Value expAdded;
89  Value expCount;
90  // Topsort (reference should remain in scope).
91  std::vector<unsigned> &topSort;
92 
93  ArrayRef<unsigned> getLoopCurStack() const {
94  ArrayRef<unsigned> topSortRef = topSort;
95  return topSortRef.slice(0, loopEmitter.getCurrentDepth());
96  }
97 
98  Value getLoopIdxValue(size_t loopIdx) const {
99  for (unsigned lv = 0; lv < topSort.size(); lv++)
100  if (topSort[lv] == loopIdx)
101  return loopEmitter.getLoopIV(lv);
102 
103  llvm_unreachable("invalid loop index");
104  }
105 };
106 
107 /// A helper class that visits an affine expression and tries to find an
108 /// AffineDimExpr whose corresponding iterator from a GenericOp matches
109 /// the desired iterator type.
110 class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
111 public:
112  explicit AffineDimFinder(linalg::GenericOp op)
113  : iterTypes(op.getIteratorTypesArray()) {}
114  void visitDimExpr(AffineDimExpr expr) {
115  if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()]) {
116  pickedDim = expr;
117  }
118  }
119 
120  /// Set the desired iterator type that we want to pick.
121  void setPickedIterType(utils::IteratorType iterType) {
122  pickIterType = iterType;
123  }
124 
125  /// Get the desired AffineDimExpr.
126  AffineDimExpr getDimExpr() const { return pickedDim.cast<AffineDimExpr>(); }
127 
128 private:
129  /// The picked AffineDimExpr after visit.
130  AffineExpr pickedDim;
131  /// The iterator type that we want.
132  utils::IteratorType pickIterType;
133  /// The mapping between dim=>iterator type.
134  SmallVector<utils::IteratorType> iterTypes;
135 };
136 } // namespace
137 
138 //===----------------------------------------------------------------------===//
139 // Sparse compiler analysis methods.
140 //===----------------------------------------------------------------------===//
141 
142 /// Determines if affine expression is invariant.
143 static bool isInvariantAffine(AffineExpr a, ArrayRef<unsigned> loopStack,
144  unsigned ldx, bool &atLevel) {
145  switch (a.getKind()) {
146  case AffineExprKind::DimId: {
147  unsigned idx = a.cast<AffineDimExpr>().getPosition();
148  if (idx == ldx) {
149  atLevel = true;
150  // Must be invariant if we are at the level.
151  return true;
152  }
153  bool isInvariant = false;
154  for (unsigned loop : loopStack) {
155  isInvariant = (loop == idx);
156  if (isInvariant)
157  break;
158  }
159  return isInvariant;
160  }
161  case AffineExprKind::Add:
162  case AffineExprKind::Mul: {
163  auto binOp = a.cast<AffineBinaryOpExpr>();
164  return isInvariantAffine(binOp.getLHS(), loopStack, ldx, atLevel) &&
165  isInvariantAffine(binOp.getRHS(), loopStack, ldx, atLevel);
166  }
167  default: {
168  assert(a.isa<AffineConstantExpr>());
169  return true;
170  }
171  }
172 }
173 
174 /// Determines if affine expression is invariant.
175 static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
176  unsigned ldx, bool &atLevel) {
177  return isInvariantAffine(a, codegen.getLoopCurStack(), ldx, atLevel);
178 }
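// Illustrative example (a sketch, not from the original file): with the
// current loop stack {0, 1} (loops i0 and i1 already entered) and ldx == 1,
// the expression d0 + d1 is invariant and atLevel is set (d1 is exactly at
// the level and d0 appears earlier in the stack), whereas d2 is not
// invariant since its loop has not been entered yet.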
179 
180 /// Helper method to construct a permuted dimension ordering
181 /// that adheres to the given topological sort.
182 static AffineMap permute(const Merger &merger, MLIRContext *context,
183  AffineMap m, ArrayRef<unsigned> topSort) {
184  assert(m.getNumDims() + merger.getNumFilterLoops() == topSort.size() &&
185  "TopoSort/AffineMap size mismatch");
186  // Construct the inverse of `m`, to avoid the asymptotic complexity
187  // of calling `m.getPermutedPosition` repeatedly.
188  SmallVector<unsigned> perm;
189  unsigned numResults = m.getNumResults();
190  BitVector worklist(numResults, true);
191  unsigned loopDepth = 1;
192 
193  // Construct the permutation.
194  while (worklist.any() && loopDepth <= topSort.size()) {
195  unsigned preSize = perm.size();
196  for (auto dim : worklist.set_bits()) {
197  bool atLevel = false;
198  if (m.getResult(dim).isa<AffineConstantExpr>() ||
199  (isInvariantAffine(m.getResult(dim), topSort.slice(0, loopDepth),
200  topSort[loopDepth - 1], atLevel) &&
201  atLevel)) {
202  // If the matching affine is a constant expression or has just become
203  // invariant, we can visit this dimension now without breaking the
204  // topSort constraint.
205  perm.push_back(dim);
206  }
207  }
208 
209  // Removes the resolved dimensions.
210  for (unsigned i = preSize, e = perm.size(); i < e; i++)
211  worklist.reset(perm[i]);
212 
213  // Tries entering the next loop level.
214  loopDepth += 1;
215  }
216 
217  assert(perm.size() == numResults);
218  return AffineMap::getPermutationMap(perm, context);
219 }
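// Illustrative example (a sketch, not from the original file): for the
// identity map m = (d0, d1) -> (d0, d1) and topSort = [1, 0], loop level 1
// resolves d1 first and level 2 resolves d0, so perm = [1, 0] and the
// returned map is the permutation (d0, d1) -> (d1, d0).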
220 
221 /// Helper method to inspect affine expressions. Rejects cases where the
222 /// same index is used more than once. Also rejects compound affine
223 /// expressions in sparse dimensions.
224 /// `filterLdx` stores the filter loop index that should be used for the next
225 /// compound affine sparse level; it is incremented by one each time it is
226 /// used.
227 static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
228  AffineExpr a, DimLevelType dlt, unsigned &filterLdx,
229  bool setLvlFormat = true) {
230  switch (a.getKind()) {
231  case AffineExprKind::DimId: {
232  unsigned idx = a.cast<AffineDimExpr>().getPosition();
233  if (!isUndefDLT(merger.getDimLevelType(tensor, idx)))
234  return false; // used more than once
235 
236  if (setLvlFormat)
237  merger.setDimAndDimLevelType(tensor, idx, dim, dlt);
238  return true;
239  }
240  case AffineExprKind::Add:
241  case AffineExprKind::Mul:
242  case AffineExprKind::Constant: {
243  if (!isDenseDLT(dlt) && setLvlFormat) {
244  assert(isUndefDLT(merger.getDimLevelType(tensor, filterLdx)));
245  // Use a filter loop for sparse affine expression.
246  merger.setDimAndDimLevelType(tensor, filterLdx++, dim, dlt);
247  }
248 
249  if (auto binOp = a.dyn_cast<AffineBinaryOpExpr>()) {
250  // We do not set the dim level format for an affine expression like
251  // d0 + d1 on either loop index d0 or d1.
252  // We continue the recursion merely to check whether the current affine
253  // expression is admissible or not.
254  return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, filterLdx,
255  false) &&
256  findAffine(merger, tensor, dim, binOp.getRHS(), dlt, filterLdx,
257  false);
258  }
259  // Falls through when it is a constant affine expression.
260  return true;
261  }
262  default:
263  return false;
264  }
265 }
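// Illustrative example (a sketch, not from the original file): for an
// operand with indexing map (d0, d1) -> (d0 + d1) whose single level is
// compressed, the compound expression d0 + d1 is not a plain AffineDimExpr,
// so a filter loop index (filterLdx) is assigned to that sparse level; the
// recursion on d0 and d1 then merely checks admissibility without setting
// any level formats.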
266 
267 /// Get the total number of compound affine expressions in affineMap that are
268 /// attached to the given tensor. For the following inputs:
269 ///
270 /// affineMap = (d0, d1, d2) => (d0 + d1, d2)
271 /// tensor = ["compressed", "compressed"]
272 ///
273 /// Returns 1 (because the first level is compressed and its corresponding
274 /// affine expression is d0 + d1).
275 static unsigned getNumCompoundAffineOnSparseDims(AffineMap affineMap,
276  Value tensor) {
277  unsigned num = 0;
278  auto enc = getSparseTensorEncoding(tensor.getType());
279  if (enc) {
280  ArrayRef<AffineExpr> exps = affineMap.getResults();
281  for (unsigned rank = 0; rank < exps.size(); rank++) {
282  auto aidx = toOrigDim(enc, rank);
283  auto affine = exps[aidx];
284  if (!affine.isa<AffineDimExpr>())
285  if (!isDenseDLT(getDimLevelType(enc, rank)))
286  num++;
287  }
288  }
289 
290  return num;
291 }
292 
293 /// Get the total number of compound affine expressions attached to a sparse
294 /// level in the given GenericOp.
295 static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) {
296  unsigned num = 0;
297  for (OpOperand &t : op->getOpOperands())
298  num += getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(&t),
299  t.get());
300  return num;
301 }
302 
303 /// Helper method to inspect sparse encodings in the tensor types.
304 /// Fills the per-dimension sparsity information for all tensors.
305 /// Returns true if the sparse annotations and affine subscript
306 /// expressions of all tensors are admissible. Returns false if
307 /// no annotations are found or inadmissible constructs occur.
308 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
309  bool annotated = false;
310  unsigned filterLdx = merger.getFilterLoopStartingIdx();
311  for (OpOperand &t : op->getOpOperands()) {
312  auto map = op.getMatchingIndexingMap(&t);
313  auto enc = getSparseTensorEncoding(t.get().getType());
314  if (enc)
315  annotated = true;
316  assert(map.getNumResults() == op.getRank(&t));
317 
318  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
319  unsigned tensor = t.getOperandNumber();
320  AffineExpr a = map.getResult(toOrigDim(enc, d));
321  if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d), filterLdx))
322  return false; // inadmissible affine expression
323  }
324  }
325  assert(filterLdx == merger.getNumLoops());
326  return annotated;
327 }
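// Illustrative example (a sketch, not from the original file): for an
// element-wise kernel C(i,j) = A(i,j) + B(i,j) where only A carries a
// (dense, compressed) encoding, the level of A indexed by i is recorded as
// dense and the level indexed by j as compressed, and `annotated` becomes
// true so sparsification proceeds for this GenericOp.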
328 
329 /// A helper to compute a topological sort. O(n^2) time complexity
330 /// as we use an adjacency matrix for the graph.
331 /// The sorted result will put the first reduction iterator at the
332 /// latest possible index.
333 static bool topSortOptimal(unsigned n,
334  ArrayRef<utils::IteratorType> iteratorTypes,
335  const Merger &merger, std::vector<unsigned> &topSort,
336  std::vector<unsigned> &inDegree,
337  std::vector<std::vector<bool>> &adjM) {
338  std::vector<unsigned> redIt; // reduce iterator with 0 degree
339  std::vector<unsigned> parIt; // parallel iterator with 0 degree
340  std::vector<unsigned> filterIt; // filter loop with 0 degree
341  for (unsigned i = 0; i < n; i++) {
342  if (inDegree[i] == 0) {
343  if (merger.isFilterLoop(i))
344  filterIt.push_back(i);
345  else if (linalg::isReductionIterator(iteratorTypes[i]))
346  redIt.push_back(i);
347  else
348  parIt.push_back(i);
349  }
350  }
351 
352  while (!redIt.empty() || !parIt.empty() || !filterIt.empty()) {
353  // We always choose in order of filter loop -> parallel loop -> reduction
354  // loop because
355  // 1. Putting reduction loop early might make the loop sequence
356  // inadmissible.
357  // 2. Filter loops should be put as early as possible for better
358  // performance, since only one (if any) iteration will carry the
359  //    computation. E.g., for (1 to N)
360  //                         for (1 to M)
361  //                           for (1 to K)
362  //                             if (xxx)
363  //                               O(X) computation => O(NMK+NMX) time complexity
364  //
365  //    By putting the filter loop one level up, we get
366  //
367  //    for (1 to N)
368  //      for (1 to K)
369  //        if (xxx)
370  //          for (1 to M)
371  //            O(X) computation => O(NK+NMX) time complexity
372  auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
373  auto src = it.back();
374  topSort.push_back(src);
375  it.pop_back();
376  // Update in-degree, and push 0-degree node into worklist.
377  for (unsigned dst = 0; dst < n; dst++) {
378  if (adjM[src][dst] && --inDegree[dst] == 0) {
379  if (merger.isFilterLoop(dst))
380  filterIt.push_back(dst);
381  else if (linalg::isReductionIterator(iteratorTypes[dst]))
382  redIt.push_back(dst);
383  else
384  parIt.push_back(dst);
385  }
386  }
387  }
388  return topSort.size() == n;
389 }
390 
391 /// Helper method to add all constraints from the indices in one affine
392 /// expression before all indices in the other affine expression. For
393 /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
394 /// The affine expression `a` is empty iff `fidx` has a value, leading to
395 /// fidx < b = (i0 + i1) => fidx < i0, fidx < i1.
396 /// The affine expression `b` is empty iff `tidx` has a value, leading to
397 /// a = (i0 + i1) < tidx => i0 < tidx, i1 < tidx.
398 static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
399  std::vector<unsigned> &inDegree, AffineExpr a,
400  AffineExpr b, Optional<unsigned> fidx,
401  Optional<unsigned> tidx) {
402  if (!a && !b) {
403  // Recursion leaf.
404  assert(fidx && tidx);
405  unsigned f = *fidx, t = *tidx;
406  if (!adjM[f][t]) {
407  adjM[f][t] = true;
408  inDegree[t]++;
409  }
410  return;
411  }
412  // Picks an affine expression and expands (recurses into) it.
413  auto toExpand = a ? a : b;
414  switch (toExpand.getKind()) {
415  case AffineExprKind::DimId: {
416  auto idx = toExpand.cast<AffineDimExpr>().getPosition();
417  if (toExpand == a)
418  addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx);
419  else // toExpand == b
420  addAffineOrderings(adjM, inDegree, a, AffineExpr(), fidx, idx);
421  break;
422  }
423  case AffineExprKind::Add:
424  case AffineExprKind::Mul: {
425  auto binOp = toExpand.cast<AffineBinaryOpExpr>();
426  if (toExpand == a) {
427  addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx, tidx);
428  addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx, tidx);
429  } else {
430  addAffineOrderings(adjM, inDegree, a, binOp.getLHS(), fidx, tidx);
431  addAffineOrderings(adjM, inDegree, a, binOp.getRHS(), fidx, tidx);
432  }
433  break;
434  }
435  default:
436  break;
437  }
438 }
439 
440 static void tryLoosenAffineDenseConstraints(linalg::GenericOp op,
441  Optional<unsigned> &fldx,
442  AffineExpr &fa,
443  Optional<unsigned> &tldx,
444  AffineExpr &ta) {
445  // We use a heuristic here to only pick one dim expression from each
446  // compound affine expression to establish the order between two dense
447  // dimensions.
448  if (!tldx) {
449  AffineDimFinder finder(op);
450  // NOTE: The ordering can only be loosened when the destination level is
451  // dense (i.e., when !tldx); for [dense, sparse] -> (d0 + d1, d2), we still
452  // require both d0 < d2 and d1 < d2 to ensure correct ordering (i.e.,
453  // no ordering like d0->d2->d1).
454  // TODO: this is obviously a suboptimal solution.
455  if (!fldx && !fa.isa<AffineConstantExpr>()) {
456  // Heuristic: we prefer a parallel loop for the lhs to reduce the chance
457  // of adding a reduce < parallel ordering.
458  finder.setPickedIterType(utils::IteratorType::parallel);
459  finder.walkPostOrder(fa);
460  fa = finder.getDimExpr();
461  fldx = finder.getDimExpr().getPosition();
462  }
463  if (!ta.isa<AffineConstantExpr>()) {
464  // Heuristic: we prefer a reduction loop for the rhs to reduce the chance
465  // of adding a reduce < parallel ordering.
466  finder.setPickedIterType(utils::IteratorType::reduction);
467  finder.walkPostOrder(ta);
468  ta = finder.getDimExpr();
469  tldx = finder.getDimExpr().getPosition();
470  }
471  }
472 }
473 
474 /// Computes a topologically sorted iteration graph for the linalg
475 /// operation. Ensures all tensors are visited in natural index order. This
476 /// is essential for sparse storage formats since these only support access
477 /// along fixed dimensions. Even for dense storage formats, however, the
478 /// natural index order yields innermost unit-stride access with better
479 /// spatial locality.
480 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
481  std::vector<unsigned> &topSort, unsigned mask,
482  OpOperand *skip = nullptr) {
483  // Set up an n x n from/to adjacency matrix of the iteration graph
484  // for the implicit loop indices i_0 .. i_n-1.
485  unsigned n = merger.getNumLoops();
486  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
487  std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
488  auto iteratorTypes = op.getIteratorTypesArray();
489  // Iterate over the indexing maps of every tensor in the tensor expression.
490  for (OpOperand &t : op->getOpOperands()) {
491  // Get map and encoding.
492  auto map = op.getMatchingIndexingMap(&t);
493  auto enc = getSparseTensorEncoding(t.get().getType());
494  assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(op) == n);
495  // Skip dense tensor constraints when not requested.
496  if (!(mask & SortMask::kIncludeDense) && !enc)
497  continue;
498  // Each tensor expression and optional dimension ordering (row-major
499  // by default) puts an ordering constraint on the loop indices. For
500  // example, the tensor expression A_ijk forces the ordering i < j < k
501  // on the loop indices if no explicit dimension ordering is given.
502  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
503  AffineExpr ta = map.getResult(toOrigDim(enc, d));
504  Optional<unsigned> tldx = merger.getLoopIdx(t.getOperandNumber(), d);
505 
506  // Filter loops should be constructed after all the dependent loops,
507  // i.e., d0 + d1 < filter_loop(d0 + d1)
508  if (tldx && merger.isFilterLoop(tldx.value())) {
509  assert(!ta.isa<AffineDimExpr>() &&
510  !isDenseDLT(getDimLevelType(enc, d)));
511  addAffineOrderings(adjM, inDegree, ta, AffineExpr(), llvm::None, tldx);
512  // Now that the ordering of the affine expression is captured by the
513  // filter loop idx, we only need to ensure the affine ordering against
514  // the filter loop. Thus, we reset the affine expression to nil here to
515  // mark it as resolved.
516  ta = AffineExpr();
517  }
518 
519  // Skip tensor during cycle resolution, though the order between a filter
520  // loop and its dependent loops needs to be guaranteed unconditionally.
521  if (&t == skip)
522  continue;
523 
524  if (d > 0) {
525  AffineExpr fa = map.getResult(toOrigDim(enc, d - 1));
526  Optional<unsigned> fldx =
527  merger.getLoopIdx(t.getOperandNumber(), d - 1);
528 
529  // Applying order constraints on every pair of dimExpr between two
530  // compound affine expressions can sometimes be too strict:
531  // E.g., for [dense, dense] -> (d0 + d1, d2 + d3),
532  // it is totally fine to have the loop sequence d0->d2->d1->d3 instead of
533  // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3.
534  if (!(mask & SortMask::kIncludeDense))
535  tryLoosenAffineDenseConstraints(op, fldx, fa, tldx, ta);
536 
537  // (d0 + d1) < (d2 + d3), or
538  // filter_loop_d-1 < (d2 + d3), or
539  // (d0 + d1) < filter_loop_d, or
540  // filter_loop_d-1 < filter_loop_d depending on whether fa/ta is reset
541  // above.
542  addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
543  }
544  }
545  // Push unrelated loops into sparse iteration space, so these
546  // will be skipped more often.
547  if (mask & SortMask::kIncludeUndef) {
548  unsigned tensor = t.getOperandNumber();
549  for (unsigned i = 0; i < n; i++)
550  if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
551  isSingletonDLT(merger.getDimLevelType(tensor, i))) {
552  for (unsigned j = 0; j < n; j++)
553  if (isUndefDLT(merger.getDimLevelType(tensor, j))) {
554  adjM[i][j] = true;
555  inDegree[j]++;
556  }
557  } else {
558  assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
559  isUndefDLT(merger.getDimLevelType(tensor, i)));
560  }
561  }
562  }
563  // Topologically sort the iteration graph to determine loop order.
564  // Report failure for a cyclic iteration graph.
565  topSort.clear();
566  topSort.reserve(n);
567  return topSortOptimal(n, iteratorTypes, merger, topSort, inDegree, adjM);
568 }
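// Illustrative example (a sketch, not from the original file): for a
// matrix-vector product x(i) = A(i,j) * b(j) with A stored as
// (dense, compressed), operand A adds the constraint i < j to the iteration
// graph, so the only admissible topological order is [i, j]; with
// kSparseOnly, the dense operands b and x contribute no constraints at all.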
569 
570 /// Returns true if tensor materializes uninitialized into the computation.
571 static bool isMaterializing(Value val) {
572  return val.getDefiningOp<tensor::EmptyOp>() ||
573  val.getDefiningOp<bufferization::AllocTensorOp>();
574 }
575 
576 /// Returns true when the tensor expression is admissible for codegen.
577 /// Since all sparse input tensors are admissible, we just need to check
578 /// whether the out tensor in the tensor expression codegen is admissible.
579 /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
580 /// nesting depth when a "truly dynamic" sparse tensor output occurs.
581 static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
582  std::vector<unsigned> &topSort, unsigned exp,
583  OpOperand **sparseOut,
584  unsigned &outerParNest) {
585  OpOperand *lhs = op.getDpsInitOperand(0);
586  unsigned tensor = lhs->getOperandNumber();
587  auto enc = getSparseTensorEncoding(lhs->get().getType());
588  // A non-annotated output tensor is assumed dense, and becomes a random
589  // access n-dim memref. Admissible since insertions cannot occur.
590  if (!enc)
591  return true;
592  // An all-dense annotated "sparse" output tensor becomes a linearized random
593  // access 1-dim memref. Also admissible since insertions cannot occur.
594  bool allDense = true;
595  unsigned numLoops = merger.getNumLoops(); // numNativeLoops + numFilterLoops
596  for (unsigned i = 0; i < merger.getNumLoops(); i++)
597  if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
598  isSingletonDLT(merger.getDimLevelType(tensor, i))) {
599  allDense = false;
600  break;
601  } else {
602  assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
603  isUndefDLT(merger.getDimLevelType(tensor, i)));
604  }
605  if (allDense)
606  return true;
607 
608  // TODO: support compound affine expression on sparse output.
609  if (getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(lhs),
610  lhs->get()) != 0)
611  return false;
612 
613  // A tensor expression with a sparse output tensor that changes its values
614  // but not its nonzero structure, an operation called "simply dynamic" in
615  // [Bik96,Ch9], is also admissible without special codegen.
616  if (merger.isSingleCondition(tensor, exp))
617  return true;
618 
619  // Accept "truly dynamic" if the output tensor materializes uninitialized
620  // into the computation and insertions occur in lexicographic index order.
621  if (isMaterializing(lhs->get())) {
622  auto iteratorTypes = op.getIteratorTypesArray();
623  unsigned nest = 0;
624  for (unsigned i = 0; i < numLoops; i++) {
625  if (!merger.isFilterLoop(topSort[i])) {
626  // We only count non-filter loops, as filter loops should be considered
627  // a special type of parallel loop.
628  if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
629  break; // terminate at first reduction
630  nest++;
631  }
632  }
633  // Determine admissible dynamic insertion situations:
634  // (1) fully injective, since there are no reductions,
635  // (2) admissible 1-d expansion in innermost dimension.
636  if (nest >= op.getRank(lhs) - 1) {
637  *sparseOut = lhs;
638  outerParNest = nest;
639  return true;
640  }
641  }
642  return false;
643 }
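// Illustrative example (a sketch, not from the original file): scaling a
// sparse vector in place, x(i) = x(i) * 2.0, only changes stored values and
// never the nonzero structure, so it is accepted above as a "simply dynamic"
// expression; whereas x(i) = a(i) + b(i) into a materializing sparse x is
// accepted as "truly dynamic", with insertions in lexicographic index order.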
644 
645 //===----------------------------------------------------------------------===//
646 // Sparse compiler synthesis methods (reductions).
647 //===----------------------------------------------------------------------===//
648 
649 /// Maps operation to reduction.
650 static Reduction getReduction(Kind kind) {
651  switch (kind) {
652  case Kind::kAddF:
653  case Kind::kAddC:
654  case Kind::kAddI:
655  case Kind::kSubF:
656  case Kind::kSubC:
657  case Kind::kSubI:
658  return kSum;
659  case Kind::kMulF:
660  case Kind::kMulC:
661  case Kind::kMulI:
662  return kProduct;
663  case Kind::kAndI:
664  return kAnd;
665  case Kind::kOrI:
666  return kOr;
667  case Kind::kXorI:
668  return kXor;
669  case Kind::kReduce:
670  return kCustom;
671  default:
672  llvm_unreachable("unexpected reduction operator");
673  }
674 }
675 
676 /// Updates scalarized reduction value.
677 static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
678  assert(codegen.redKind != kNoReduc);
679  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
680 }
681 
682 /// Extracts identity from custom reduce.
683 static Value getCustomRedId(Operation *op) {
684  return dyn_cast<sparse_tensor::ReduceOp>(op).getIdentity();
685 }
686 
687 //===----------------------------------------------------------------------===//
688 // Sparse compiler synthesis methods (statements and expressions).
689 //===----------------------------------------------------------------------===//
690 
691 /// Generates loop boundary statements (entering/exiting loops). The function
692 /// passes and updates the reduction value.
693 static Optional<Operation *> genLoopBoundary(
694  CodeGen &codegen, Merger &merger,
695  function_ref<Optional<Operation *>(MutableArrayRef<Value> reduc)>
696  callback) {
697  SmallVector<Value> reduc;
698  if (codegen.redVal)
699  reduc.push_back(codegen.redVal);
700  if (codegen.expValues)
701  reduc.push_back(codegen.expCount);
702  if (codegen.insChain)
703  reduc.push_back(codegen.insChain);
704 
705  auto r = callback(reduc);
706 
707  // Callback should do in-place update on reduction value vector.
708  unsigned i = 0;
709  if (codegen.redVal)
710  updateReduc(merger, codegen, reduc[i++]);
711  if (codegen.expValues)
712  codegen.expCount = reduc[i++];
713  if (codegen.insChain)
714  codegen.insChain = reduc[i];
715 
716  return r;
717 }
718 
719 /// Local bufferization of all dense and sparse data structures.
720 static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
721  linalg::GenericOp op) {
722  Location loc = op.getLoc();
723  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
724 
725  codegen.loopEmitter.initializeLoopEmit(
726  builder, loc,
727  /// Generates buffer for the output tensor.
728  /// Note that all sparse kernels assume that when all elements are written
729  /// to (viz. x(i) = y(i) * z(i)), the output buffer is already initialized
730  /// to all zeroes and only nonzero values are computed and written out.
731  /// For updates (viz. x(i) += y(i) * z(i)), only nonzero values are used
732  /// for the updates and no assumption on the original contents of the
733  /// output buffer is necessary.
734  [&op](OpBuilder &builder, Location loc, Value memref,
735  Value tensor) -> Value {
736  // Must not be a sparse tensor.
737  assert(!getSparseTensorEncoding(tensor.getType()));
738  OpOperand *lhs = op.getDpsInitOperand(0);
739  // Both output tensor references should point to the same object.
740  assert(lhs->get() == tensor);
741  bool isInit = op.isInitTensor(lhs);
742  // An output tensor can simply materialize from the buffer of the tensor
743  // that appears in the outs() clause. For updates, this has the
744  // advantage that only the nonzero values are involved in the
745  // computation, keeping the operation O(nnz). In all other cases, we are
746  // forced to zero out the buffer to enforce the assumption above, which
747  // may negatively impact running complexity (viz. O(n^2 + nnz) vs.
748  // O(nnz) for matrices).
749  // TODO: use better analysis to avoid zeroing out the buffer?
750  Value init = memref;
751  if (!isInit) {
752  Value zero = constantZero(builder, loc,
753  getElementTypeOrSelf(tensor.getType()));
754  builder.create<linalg::FillOp>(loc, ValueRange{zero},
755  ValueRange{init});
756  }
757  return init;
758  });
759 }
760 
761 /// Generates index for load/store on sparse tensor.
762 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
763  auto map = op.getMatchingIndexingMap(t);
764  auto enc = getSparseTensorEncoding(t->get().getType());
765  AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
766  assert(a.getKind() == AffineExprKind::DimId);
767  unsigned idx = a.cast<AffineDimExpr>().getPosition();
768  return codegen.getLoopIdxValue(idx);
769 }
770 
771 /// Generates subscript for load/store on a dense or sparse tensor.
772 static Value genSubscript(CodeGen &codegen, OpBuilder &builder,
773  linalg::GenericOp op, OpOperand *t,
774  SmallVectorImpl<Value> &args) {
775  unsigned tensor = t->getOperandNumber();
776  auto map = op.getMatchingIndexingMap(t);
777  auto enc = getSparseTensorEncoding(t->get().getType());
778  unsigned rank = map.getNumResults();
779  if (enc) {
780  Value pidx = codegen.loopEmitter.getPidxs()[tensor].back();
781  assert(pidx);
782  args.push_back(pidx); // position index
783  } else {
784  for (unsigned d = 0; d < rank; d++) {
785  AffineExpr a = map.getResult(d);
786  args.push_back(codegen.loopEmitter.genAffine(builder, a, op.getLoc()));
787  }
788  }
789  return codegen.loopEmitter.getValBuffer()[tensor];
790 }
791 
792 /// Generates insertion code to implement dynamic tensor load.
793 static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder,
794  linalg::GenericOp op, OpOperand *t) {
795  Location loc = op.getLoc();
796  // Direct lexicographic index order, tensor loads as zero.
797  if (!codegen.expValues) {
798  Type tp = getElementTypeOrSelf(t->get().getType());
799  return constantZero(builder, loc, tp);
800  }
801  // Load from expanded access pattern.
802  Value index = genIndex(codegen, op, t);
803  return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
804 }
805 
806 /// Generates insertion code to implement dynamic tensor load for reduction.
807 static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
808  OpBuilder &builder, linalg::GenericOp op,
809  OpOperand *t) {
810  Location loc = op.getLoc();
811  Value identity = getCustomRedId(merger.exp(codegen.redCustom).op);
812  // Direct lexicographic index order, tensor loads as identity.
813  if (!codegen.expValues) {
814  return identity;
815  }
816  // Load from expanded access pattern if filled, identity otherwise.
817  Value index = genIndex(codegen, op, t);
818  Value isFilled =
819  builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
820  Value valAtIndex =
821  builder.create<memref::LoadOp>(loc, codegen.expValues, index);
822  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
823 }
824 
825 /// Generates insertion code to implement dynamic tensor store.
826 static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
827  linalg::GenericOp op, OpOperand *t, Value rhs) {
828  Location loc = op.getLoc();
829  // Direct insertion in lexicographic index order.
830  if (!codegen.expValues) {
831  unsigned rank = op.getRank(t);
832  SmallVector<Value> indices;
833  for (unsigned i = 0; i < rank; i++) {
834  assert(codegen.loopEmitter.getLoopIV(i));
835  indices.push_back(codegen.loopEmitter.getLoopIV(i));
836  }
837  codegen.insChain =
838  builder.create<InsertOp>(loc, rhs, codegen.insChain, indices);
839  return;
840  }
841  // Generates insertion code along expanded access pattern.
842  // if (!expFilled[i]) then
843  // expFilled[i] = true
844  // expAdded[inserts++] = i
845  // endif
846  // values[i] = rhs
847  Value index = genIndex(codegen, op, t);
848  Value fval = constantI1(builder, loc, false);
849  Value tval = constantI1(builder, loc, true);
850  // If statement.
851  Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
852  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
853  filled, fval);
854  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
855  /*else=*/true);
856  // True branch.
857  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
858  builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
859  builder.create<memref::StoreOp>(loc, index, codegen.expAdded,
860  codegen.expCount);
861  Value one = constantIndex(builder, loc, 1);
862  Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one);
863  builder.create<scf::YieldOp>(loc, add);
864  // False branch.
865  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
866  builder.create<scf::YieldOp>(loc, codegen.expCount);
867  builder.setInsertionPointAfter(ifOp);
868  // Value assignment.
869  codegen.expCount = ifOp.getResult(0);
870  builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
871 }
872 
873 /// Generates a load on a dense or sparse tensor.
874 static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
875  linalg::GenericOp op, unsigned exp) {
876  // Test if the load was hoisted to a higher loop nest.
877  Value val = merger.exp(exp).val;
878  if (val)
879  return val;
880 
881  // Load during insertion.
882  OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
883  if (&t == codegen.sparseOut) {
884  if (codegen.redCustom != -1u)
885  return genInsertionLoadReduce(merger, codegen, builder, op, &t);
886  return genInsertionLoad(codegen, builder, op, &t);
887  }
888  // Actual load.
889  SmallVector<Value> args;
890  Value ptr = genSubscript(codegen, builder, op, &t, args);
891  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
892 }
893 
894 /// Generates a store on a dense or sparse tensor.
895 static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
896  linalg::GenericOp op, unsigned exp, Value rhs) {
897  Location loc = op.getLoc();
898  // Test if this is a scalarized reduction.
899  if (codegen.redVal) {
900  updateReduc(merger, codegen, rhs);
901  return;
902  }
903  // Store during insertion.
904  OpOperand *t = op.getDpsInitOperand(0);
905  if (t == codegen.sparseOut) {
906  if (!rhs) {
907  // Only unary and binary ops are allowed to return an uninitialized rhs
908  // to indicate a missing output.
909  assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
910  } else if (merger.exp(exp).kind == kSelect) {
911  // Select operation insertion.
912  Value insChain = codegen.insChain;
913  assert(insChain);
914  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, insChain.getType(), rhs,
915  /*else=*/true);
916  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
917  // Existing value was preserved to be used here.
918  assert(merger.exp(exp).val);
919  Value v0 = merger.exp(exp).val;
920  genInsertionStore(codegen, builder, op, t, v0);
921  merger.exp(exp).val = Value();
922  // Yield modified insertion chain along true branch.
923  builder.create<scf::YieldOp>(op.getLoc(), codegen.insChain);
924  // Yield original insertion chain along false branch.
925  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
926  builder.create<scf::YieldOp>(loc, insChain);
927  // Done with if statement.
928  codegen.insChain = ifOp->getResult(0);
929  builder.setInsertionPointAfter(ifOp);
930  } else {
931  genInsertionStore(codegen, builder, op, t, rhs);
932  }
933  return;
934  }
935  // Actual store.
936  SmallVector<Value> args;
937  Value ptr = genSubscript(codegen, builder, op, t, args);
938  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
939 }
940 
941 /// Generates an invariant value.
942 inline static Value genInvariantValue(Merger &merger, CodeGen &codegen,
943  OpBuilder &builder, unsigned exp) {
944  return merger.exp(exp).val;
945 }
946 
947 /// Generates an index value.
948 inline static Value genIndexValue(CodeGen &codegen, OpBuilder &builder,
949  unsigned idx) {
950  return codegen.getLoopIdxValue(idx);
951 }
952 
953 /// Semi-ring branches are simply inlined by the sparse compiler. Prior
954 /// analysis has verified that all computations are "local" to the inlined
955 /// branch or otherwise invariantly defined outside the loop nest, with the
956 /// exception of index computations, which need to be relinked to actual
957 /// inlined cloned code.
958 static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter,
959  Block *block, Value e, unsigned ldx) {
960  if (Operation *def = e.getDefiningOp()) {
961  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
962  return genIndexValue(codegen, rewriter, indexOp.getDim());
963  if (def->getBlock() == block) {
964  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
965  def->setOperand(
966  i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx));
967  }
968  }
969  return e;
970 }
971 
972 /// Recursively generates tensor expression.
973 static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
974  linalg::GenericOp op, unsigned exp, unsigned ldx) {
975  Location loc = op.getLoc();
976  if (exp == -1u)
977  return Value();
978  if (merger.exp(exp).kind == Kind::kTensor)
979  return genTensorLoad(merger, codegen, rewriter, op, exp);
980  if (merger.exp(exp).kind == Kind::kInvariant)
981  return genInvariantValue(merger, codegen, rewriter, exp);
982  if (merger.exp(exp).kind == Kind::kIndex)
983  return genIndexValue(codegen, rewriter, merger.exp(exp).index);
984 
985  if (merger.exp(exp).kind == Kind::kReduce) {
986  // Make custom reduction identity accessible for expanded access pattern.
987  assert(codegen.redCustom == -1u);
988  codegen.redCustom = exp;
989  }
990 
991  Value v0 =
992  genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
993  Value v1 =
994  genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
995  Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
996  if (ee && (merger.exp(exp).kind == Kind::kUnary ||
997  merger.exp(exp).kind == Kind::kBinary ||
998  merger.exp(exp).kind == Kind::kBinaryBranch ||
999  merger.exp(exp).kind == Kind::kReduce ||
1000  merger.exp(exp).kind == Kind::kSelect))
1001  ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
1002 
1003  if (merger.exp(exp).kind == kSelect) {
1004  assert(!merger.exp(exp).val);
1005  merger.exp(exp).val = v0; // Preserve value for later use.
1006  }
1007 
1008  if (merger.exp(exp).kind == Kind::kReduce) {
1009  assert(codegen.redCustom != -1u);
1010  codegen.redCustom = -1u;
1011  }
1012 
1013  return ee;
1014 }
1015 
1016 /// Hoists loop invariant tensor loads for which indices have been exhausted.
1017 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1018  linalg::GenericOp op, unsigned exp, unsigned ldx,
1019  bool atStart, unsigned last = -1u) {
1020  if (exp == -1u)
1021  return;
1022  if (merger.exp(exp).kind == Kind::kTensor) {
1023  // Inspect tensor indices.
1024  bool atLevel = ldx == -1u;
1025  OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
1026  auto map = op.getMatchingIndexingMap(&t);
1027  auto enc = getSparseTensorEncoding(t.get().getType());
1028  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
1029  AffineExpr a = map.getResult(toOrigDim(enc, d));
1030  Optional<unsigned> sldx = merger.getLoopIdx(t.getOperandNumber(), d);
1031  if (sldx && merger.isFilterLoop(sldx.value())) {
1032  if (!codegen.getLoopIdxValue(sldx.value()))
1033  // The filter loop has not been constructed.
1034  return;
1035  if (sldx.value() == ldx)
1036  atLevel = true;
1037  } else if (!isInvariantAffine(codegen, a, ldx, atLevel))
1038  return; // still in play
1039  }
1040  // All exhausted at this level (atLevel denotes exactly at this level).
1041  if (!atLevel)
1042  return;
1043  OpOperand *lhs = op.getDpsInitOperand(0);
1044  if (lhs == &t) {
1045  // Start or end a scalarized reduction
1046  if (atStart) {
1047  Kind kind = merger.exp(last).kind;
1048  Value load = kind == Kind::kReduce
1049  ? getCustomRedId(merger.exp(last).op)
1050  : genTensorLoad(merger, codegen, builder, op, exp);
1051  codegen.redKind = getReduction(kind);
1052  codegen.redExp = exp;
1053  updateReduc(merger, codegen, load);
1054  } else {
1055  Value redVal = codegen.redVal;
1056  updateReduc(merger, codegen, Value());
1057  codegen.redExp = -1u;
1058  codegen.redKind = kNoReduc;
1059  genTensorStore(merger, codegen, builder, op, exp, redVal);
1060  }
1061  } else {
1062  // Start or end loop invariant hoisting of a tensor load.
1063  merger.exp(exp).val =
1064  atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value();
1065  }
1066  } else if (merger.exp(exp).kind != Kind::kInvariant &&
1067  merger.exp(exp).kind != Kind::kIndex) {
1068  // Traverse into the binary operations. Note that we only hoist
1069  // tensor loads, since subsequent MLIR/LLVM passes know how to
1070  // deal with all other kinds of derived loop invariants.
1071  unsigned e0 = merger.exp(exp).children.e0;
1072  unsigned e1 = merger.exp(exp).children.e1;
1073  genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp);
1074  genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp);
1075  }
1076 }
1077 
1078 /// Generates an expanded access pattern in innermost dimension.
1079 static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1080  linalg::GenericOp op, unsigned at, bool atStart) {
1081  OpOperand *lhs = codegen.sparseOut;
1082  if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
1083  at != codegen.outerParNest)
1084  return; // not needed at this level
1085  assert(codegen.redVal == nullptr);
1086  // Generate start or end of an expanded access pattern. Note that because
1087  // an expansion does not rely on the ongoing contents of the sparse storage
1088  // scheme, we can use the original tensor as incoming SSA value (which
1089  // simplifies codegen a bit). If expansion on the actual contents is ever
1090  // needed, we will need to use the SSA value in the insertion chain instead.
1091  Value tensor = lhs->get();
1092  Location loc = op.getLoc();
1093  if (atStart) {
1094  auto dynShape = {ShapedType::kDynamic};
1095  Type etp = tensor.getType().cast<ShapedType>().getElementType();
1096  Type t1 = MemRefType::get(dynShape, etp);
1097  Type t2 = MemRefType::get(dynShape, builder.getI1Type());
1098  Type t3 = MemRefType::get(dynShape, builder.getIndexType());
1099  Type t4 = builder.getIndexType();
1100  auto res =
1101  builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
1102  assert(res.getNumResults() == 4);
1103  assert(!codegen.expValues);
1104  codegen.expValues = res.getResult(0);
1105  codegen.expFilled = res.getResult(1);
1106  codegen.expAdded = res.getResult(2);
1107  codegen.expCount = res.getResult(3);
1108  } else {
1109  assert(codegen.expValues);
1110  SmallVector<Value> indices;
1111  for (unsigned i = 0; i < at; i++) {
1112  assert(codegen.loopEmitter.getLoopIV(i));
1113  indices.push_back(codegen.loopEmitter.getLoopIV(i));
1114  }
1115  codegen.insChain = builder.create<CompressOp>(
1116  loc, codegen.expValues, codegen.expFilled, codegen.expAdded,
1117  codegen.expCount, codegen.insChain, indices);
1118  codegen.expValues = codegen.expFilled = codegen.expAdded =
1119  codegen.expCount = Value();
1120  }
1121 }
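// Illustrative example (a sketch, not from the original file): for a sparse
// matrix-times-matrix kernel C(i,j) = A(i,k) * B(k,j) with sparse output C,
// the expand op above provides per-row scratch buffers (values, filled,
// added) plus a count, entries are filled while iterating the inner loops,
// and the compress op folds the accumulated row back into the insertion
// chain once the outer i iteration finishes.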
1122 
1123 /// Returns parallelization strategy. Any implicit loop in the Linalg
1124 /// operation that is marked "parallel" is a candidate. Whether it is actually
1125 /// converted to a parallel operation depends on the requested strategy.
1126 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isSparse) {
1127  // Reject parallelization of sparse output.
1128  if (codegen.sparseOut)
1129  return false;
1130  // Parallel loops on tensor expansion can cause data races.
1131  if (codegen.expCount)
1132  return false;
1133  // Inspect strategy.
1134  switch (codegen.options.parallelizationStrategy) {
1135  case SparseParallelizationStrategy::kNone:
1136  return false;
1137  case SparseParallelizationStrategy::kDenseOuterLoop:
1138  return isOuter && !isSparse;
1139  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
1140  return isOuter;
1141  case SparseParallelizationStrategy::kDenseAnyLoop:
1142  return !isSparse;
1143  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
1144  return true;
1145  }
1146  llvm_unreachable("unexpected parallelization strategy");
1147 }
1148 
1149 /// Generates a for-loop on a single index.
1150 static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1151  linalg::GenericOp op, bool isOuter, bool isInner,
1152  unsigned idx, size_t tid, size_t dim,
1153  ArrayRef<size_t> extraTids,
1154  ArrayRef<size_t> extraDims) {
1155  Location loc = op.getLoc();
1156  bool isSparse = isCompressedDLT(merger.getDimLevelType(tid, idx)) ||
1157  isSingletonDLT(merger.getDimLevelType(tid, idx));
1158  bool isParallel = isParallelFor(codegen, isOuter, isSparse);
1159 
1160  Operation *loop =
1161  genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
1162  if (merger.isFilterLoop(idx)) {
1163  // extraTids/extraDims must be empty because a filter loop only
1164  // corresponds to the one and only sparse tensor level.
1165  assert(isSparse && extraTids.empty() && extraDims.empty());
1166  OpOperand *t = &op->getOpOperand(tid);
1167  auto enc = getSparseTensorEncoding(t->get().getType());
1168  // Retrieves the affine expression for the filter loop.
1169  AffineExpr a =
1170  op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim));
1171  return codegen.loopEmitter.enterFilterLoopOverTensorAtDim(
1172  builder, loc, tid, dim, a, reduc);
1173  }
1174  return codegen.loopEmitter.enterLoopOverTensorAtDim(
1175  builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
1176  }).value();
1177  assert(loop);
1178  return loop;
1179 }
1180 
1181 /// Emit a while-loop for co-iteration over multiple indices.
1182 static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1183  linalg::GenericOp op, unsigned idx, bool needsUniv,
1184  ArrayRef<size_t> condTids, ArrayRef<size_t> condDims,
1185  ArrayRef<size_t> extraTids,
1186  ArrayRef<size_t> extraDims) {
1187 
1188  Operation *loop =
1189  genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
1190  // Construct the while-loop with a parameter for each index.
1191  return codegen.loopEmitter.enterCoIterationOverTensorsAtDims(
1192  builder, op.getLoc(), condTids, condDims, needsUniv, reduc,
1193  extraTids, extraDims);
1194  }).value();
1195  assert(loop);
1196  return loop;
1197 }
1198 
1199 /// Generates a for-loop or a while-loop, depending on whether it implements
1200 /// singleton iteration or co-iteration over the given conjunction.
1201 static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1202  linalg::GenericOp op, unsigned at, bool needsUniv,
1203  ArrayRef<size_t> condTids, ArrayRef<size_t> condDims,
1204  ArrayRef<size_t> extraTids,
1205  ArrayRef<size_t> extraDims) {
1206  assert(condTids.size() == condDims.size());
1207  assert(extraTids.size() == extraDims.size());
1208  unsigned idx = codegen.topSort[at];
1209  if (condTids.size() == 1) {
1210  bool isOuter = at == 0;
1211  bool isInner = at == codegen.topSort.size() - 1;
1212  return genFor(merger, codegen, builder, op, isOuter, isInner, idx,
1213  condTids.front(), condDims.front(), extraTids, extraDims);
1214  }
1215  return genWhile(merger, codegen, builder, op, idx, needsUniv, condTids,
1216  condDims, extraTids, extraDims);
1217 }
1218 
1219 /// Generates the induction structure for a while-loop.
1220 static void finalizeWhileOp(Merger &merger, CodeGen &codegen,
1221  OpBuilder &builder, linalg::GenericOp op,
1222  unsigned idx, bool needsUniv, BitVector &induction,
1223  scf::WhileOp whileOp) {
1224  Location loc = op.getLoc();
1225  // Finalize each else branch of all if statements.
1226  if (codegen.redVal || codegen.expValues || codegen.insChain) {
1227  while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
1228  builder.getInsertionBlock()->getParentOp())) {
1229  unsigned y = 0;
1230  SmallVector<Value> yields;
1231  if (codegen.redVal) {
1232  yields.push_back(codegen.redVal);
1233  updateReduc(merger, codegen, ifOp.getResult(y++));
1234  }
1235  if (codegen.expValues) {
1236  yields.push_back(codegen.expCount);
1237  codegen.expCount = ifOp->getResult(y++);
1238  }
1239  if (codegen.insChain) {
1240  yields.push_back(codegen.insChain);
1241  codegen.insChain = ifOp->getResult(y++);
1242  }
1243  assert(y == yields.size());
1244  builder.create<scf::YieldOp>(loc, yields);
1245  builder.setInsertionPointAfter(ifOp);
1246  }
1247  }
1248  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
1249 }
1250 
1251 /// Generates a single if-statement within a while-loop.
1252 static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1253  linalg::GenericOp op, unsigned idx,
1254  BitVector &conditions) {
1255  Location loc = op.getLoc();
1256  SmallVector<Type> types;
1257  Value cond;
1258  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
1259  if (!conditions[b])
1260  continue;
1261  unsigned tensor = merger.tensor(b);
1262  assert(idx == merger.index(b));
1263  Value clause;
1264  if (isCompressedDLT(merger.getDimLevelType(b)) ||
1265  isSingletonDLT(merger.getDimLevelType(b))) {
1266  auto dim = merger.getDimNum(tensor, idx).value();
1267  Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
1268  Value op2 = codegen.getLoopIdxValue(idx);
1269  clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
1270  op2);
1271  } else {
1272  assert(isDenseDLT(merger.getDimLevelType(b)) ||
1273  isUndefDLT(merger.getDimLevelType(b)));
1274  clause = constantI1(builder, loc, true);
1275  }
1276  cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
1277  }
1278  if (codegen.redVal)
1279  types.push_back(codegen.redVal.getType());
1280  if (codegen.expValues)
1281  types.push_back(builder.getIndexType());
1282  if (codegen.insChain)
1283  types.push_back(codegen.insChain.getType());
1284  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
1285  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1286  return ifOp;
1287 }
1288 
1289 /// Generates end of true branch of if-statement within a while-loop.
1290 static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1291  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
1292  Value redInput, Value cntInput, Value insInput) {
1293  SmallVector<Value> operands;
1294  if (codegen.redVal) {
1295  operands.push_back(codegen.redVal);
1296  updateReduc(merger, codegen, redInput);
1297  }
1298  if (codegen.expValues) {
1299  operands.push_back(codegen.expCount);
1300  codegen.expCount = cntInput;
1301  }
1302  if (codegen.insChain) {
1303  operands.push_back(codegen.insChain);
1304  codegen.insChain = insInput;
1305  }
1306  if (!operands.empty())
1307  builder.create<scf::YieldOp>(op.getLoc(), operands);
1308  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1309 }
1310 
1311 //===----------------------------------------------------------------------===//
1312 // Sparse compiler synthesis methods (loop sequence).
1313 //===----------------------------------------------------------------------===//
1314 
1315 /// Starts a loop sequence at given level. Returns true if
1316 /// the universal loop index must be maintained at this level.
1317 static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1318  linalg::GenericOp op, unsigned exp, unsigned at,
1319  unsigned idx, unsigned ldx, unsigned lts) {
1320  assert(!codegen.getLoopIdxValue(idx));
1321  // Emit invariants at this loop sequence level.
1322  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true);
1323  // Emit access pattern expansion for sparse tensor output.
1324  genExpansion(merger, codegen, builder, op, at, /*atStart=*/true);
1325  // Emit further initialization at this loop sequence level.
1326  unsigned l0 = merger.set(lts)[0];
1327  bool needsUniv = false;
1328 
1329  SmallVector<size_t> tids;
1330  SmallVector<size_t> dims;
1331  merger.foreachTidDimPairInBits(
1332  merger.lat(l0).bits,
1333  [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
1334  assert(merger.index(b) == idx);
1335  if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
1336  needsUniv = true;
1337  } else {
1338  // sparse/singleton dim levels.
1339  tids.push_back(tid);
1340  dims.push_back(dim.value());
1341  }
1342  });
1343 
1344  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);
1345 
1346  // Maintain the universal index only if it is actually
1347  // consumed by a subsequent lattice point.
1348  if (needsUniv) {
1349  unsigned lsize = merger.set(lts).size();
1350  for (unsigned i = 1; i < lsize; i++) {
1351  unsigned li = merger.set(lts)[i];
1352  if (!merger.hasAnySparse(merger.lat(li).simple))
1353  return true;
1354  }
1355  }
1356  return false;
1357 }
1358 
1359 static void genConstantDenseAddressFromLevel(CodeGen &codegen,
1360  OpBuilder &builder,
1361  linalg::GenericOp op, unsigned tid,
1362  unsigned lvl) {
1363  // TODO: Handle affine expression on output tensor.
1364  assert(tid < op.getNumDpsInputs());
1365 
1366  OpOperand *input = op.getDpsInputOperands()[tid];
1367  ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
1368  auto enc = getSparseTensorEncoding(input->get().getType());
1369  if (enc) {
1370  for (unsigned i = lvl, e = affines.size(); i < e; i++) {
1371  AffineExpr affine = affines[toOrigDim(enc, i)];
1372  if (isDenseDLT(getDimLevelType(enc, i)) &&
1373  affine.isa<AffineConstantExpr>()) {
1374  codegen.loopEmitter.genDenseAffineAddressAtCurLevel(
1375  builder, op.getLoc(), input->getOperandNumber(), i, affine);
1376  } else {
1377  // Breaks on first non-dense non-constant level.
1378  return;
1379  }
1380  }
1381  }
1382 }
1383 
1384 static void genInitConstantDenseAddress(CodeGen &codegen,
1385  RewriterBase &rewriter,
1386  linalg::GenericOp op) {
1387  // We can generate addresses for constant affine expressions before any loops,
1388  // starting from the first level, as they do not depend on anything.
1389  // E.g., for [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first
1390  // two levels can be determined before the loops.
1391  for (unsigned tid = 0, e = op.getNumDpsInputs(); tid < e; tid++)
1392  genConstantDenseAddressFromLevel(codegen, rewriter, op, tid, 0);
1393 }
1394 
1395 static void translateBitsToTidDimPairs(
1396  Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li,
1397  unsigned idx, SmallVectorImpl<size_t> &condTids,
1398  SmallVectorImpl<size_t> &condDims, SmallVectorImpl<size_t> &extraTids,
1399  SmallVectorImpl<size_t> &extraDims, SmallVectorImpl<size_t> &affineTids,
1400  SmallVectorImpl<size_t> &affineDims, SmallVectorImpl<AffineExpr> &exps) {
1401 
1402  const BitVector &all = merger.lat(li).bits;
1403  const BitVector &simple = merger.lat(li).simple;
1404 
1405  // Converts bits to array + dim pair
1406  merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
1407  Optional<unsigned> dim,
1408  DimLevelType dlt) {
1409  if (simple.test(b)) {
1410  if (isUndefDLT(dlt)) {
1411  // An undefined dlt in the lattices, we probably mean to iterate based
1412  // on the dim of output tensor.
1413  // E.g., this could be a synthetic tensor (for invariants and sparse
1414  // output tensor).
1415  // out[i][j] = invariant; or a broadcast
1416  // out[i][j] = in[i] (j is undef for input)
1417  tid = merger.getOutTensorID();
1418  dim = merger.getDimNum(tid, idx);
1419  // Skips invalid dim (e.g., when this is a zero ranked tensor).
1420  if (!dim)
1421  return;
1422  }
1423  condTids.push_back(tid);
1424  condDims.push_back(dim.value());
1425  } else if (isDenseDLT(dlt)) {
1426  // TODO: get rid of extraTids and extraDims.
1427  extraTids.push_back(tid);
1428  extraDims.push_back(dim.value());
1429  } else {
1430  assert(isUndefDLT(dlt));
1431  if (tid >= op.getNumDpsInputs())
1432  // We only handle affine expressions on input tensors (for now).
1433  return;
1434  OpOperand *operand = &op->getOpOperand(tid);
1435  auto enc = getSparseTensorEncoding(operand->get().getType());
1436  // Non-annotated dense tensors require no special handling.
1437  if (!enc)
1438  return;
1439 
1440  ArrayRef<AffineExpr> affines =
1441  op.getMatchingIndexingMap(operand).getResults();
1442  assert(affines.size() == enc.getDimLevelType().size());
1443  for (unsigned i = 0, e = affines.size(); i < e; i++) {
1444  AffineExpr exp = affines[toOrigDim(enc, i)];
1445  // Skip simple affine expressions and non-dense dimensions (which have
1446  // their own filter loops).
1447  if (exp.isa<AffineDimExpr>() || !isDenseDLT(getDimLevelType(enc, i)))
1448  continue;
1449 
1450  // Constant affine expressions are handled in genLoop.
1451  if (!exp.isa<AffineConstantExpr>()) {
1452  bool atLevel = false;
1453  if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
1454  // If the compound affine expression is invariant and we are right at
1455  // this level, we need to generate the address according to the affine
1456  // expression. This is also the best place to do so, since it avoids
1457  // putting the computation inside inner loops.
1458  // NOTE: This assumes that the levels of the input tensor are
1459  // initialized in order (which is currently guaranteed by
1460  // computeIterationGraph); a more permissive approach might accept
1461  // out-of-order accesses between consecutive dense levels.
1462  affineTids.push_back(tid);
1463  affineDims.push_back(i);
1464  exps.push_back(exp);
1465  }
1466  }
1467  }
1468  }
1469  });
1470 
1471  if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
1472  // Note that we generate dense indices of the output tensor
1473  // unconditionally, since they may not appear in the lattice, but may be
1474  // needed for linearized codegen.
1475  auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
1476  extraTids.push_back(merger.getOutTensorID());
1477  extraDims.push_back(dim);
1478  }
1479 }
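// For intuition (hypothetical kernel, simplified): in x(i) = a(i) * b(i) with
// a compressed and b dense, the lattice point for loop i carries bits for both
// tensors; a ends up in condTids/condDims because it drives the loop
// condition, while b ends up in extraTids/extraDims because it is dense and
// imposes no condition, yet still needs its locals set up at this level. The
// output tensor contributes its dense dim unconditionally, as noted above.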
1480 
1481 /// Starts a single loop in the current sequence.
1482 static Operation *startLoop(Merger &merger, CodeGen &codegen,
1483  OpBuilder &builder, linalg::GenericOp op,
1484  unsigned at, unsigned li, bool needsUniv) {
1485  // The set of tensors + dims to generate loops on
1486  SmallVector<size_t> condTids, condDims;
1487  // The set of (dense) tensors that are optimized away from the loop
1488  // condition, yet still need extra locals to iterate on them.
1489  SmallVector<size_t> extraTids, extraDims;
1490  // The set of dense tensors with a non-trivial affine expression that just
1491  // became invariant; the corresponding addresses shall now be generated at
1492  // the current level.
1493  SmallVector<size_t> affineTids, affineDims;
1494  SmallVector<AffineExpr> affines;
1495 
1496  translateBitsToTidDimPairs(merger, codegen, op, li, codegen.topSort[at],
1497  condTids, condDims, extraTids, extraDims,
1498  affineTids, affineDims, affines);
1499  // Emit the for/while-loop control.
1500  Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv,
1501  condTids, condDims, extraTids, extraDims);
1502 
1503  for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
1504  codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(),
1505  tid, dim, exp);
1506  }
1507 
1508  // By now, we have entered every <tid, dim> pair in {cond, extra,
1509  // affine}Tids/Dims. The addresses of the upcoming levels that depend on
1510  // constant affine expressions may now be determined.
1511  auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
1512  auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
1513  for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
1514  if (tid != merger.getOutTensorID())
1515  genConstantDenseAddressFromLevel(codegen, builder, op, tid, dim + 1);
1516  }
1517 
1518  return loop;
1519 }
1520 
1521 /// Ends a single loop in the current sequence. Returns the new value for needsUniv.
1522 static bool endLoop(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
1523  linalg::GenericOp op, Operation *loop, unsigned idx,
1524  unsigned li, bool needsUniv) {
1525  // End a while-loop.
1526  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1527  finalizeWhileOp(merger, codegen, rewriter, op, idx, needsUniv,
1528  merger.lat(li).bits, whileOp);
1529  } else {
1530  needsUniv = false;
1531  }
1532 
1533  genLoopBoundary(codegen, merger, [&](MutableArrayRef<Value> reduc) {
1534  codegen.loopEmitter.exitCurrentLoop(rewriter, op.getLoc(), reduc);
1535  return llvm::None;
1536  });
1537 
1538  return needsUniv;
1539 }
1540 
1541 /// Ends a loop sequence at the given level.
1542 static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
1543  linalg::GenericOp op, unsigned exp, unsigned at,
1544  unsigned idx, unsigned ldx) {
1545  assert(codegen.getLoopIdxValue(idx) == nullptr);
1546  codegen.loopEmitter.exitCurrentLoopSeq();
1547  // Unmark bookkeeping of invariants and loop index.
1548  genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false);
1549  // Finalize access pattern expansion for sparse tensor output.
1550  genExpansion(merger, codegen, builder, op, at, /*atStart=*/false);
1551 }
1552 
1553 /// Recursively generates code while computing iteration lattices in order
1554 /// to manage the complexity of implementing co-iteration over unions
1555 /// and intersections of sparse iteration spaces.
1556 static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
1557  linalg::GenericOp op, unsigned exp, unsigned at) {
1558  // At each leaf, assign remaining tensor (sub)expression to output tensor.
1559  if (at == codegen.topSort.size()) {
1560  unsigned ldx = codegen.topSort[at - 1];
1561  Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
1562  genTensorStore(merger, codegen, rewriter, op, exp, rhs);
1563  return;
1564  }
1565 
1566  // Construct iteration lattices for current loop index, with L0 at top.
1567  unsigned idx = codegen.topSort[at];
1568  unsigned ldx = at == 0 ? -1u : codegen.topSort[at - 1];
1569  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
1570 
1571  // TODO: sort
1572  // TODO: dedup
1573 
1574  // Start a loop sequence.
1575  bool needsUniv =
1576  startLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx, lts);
1577 
1578  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1579  unsigned lsize = merger.set(lts).size();
1580  for (unsigned i = 0; i < lsize; i++) {
1581  // Start a loop.
1582  unsigned li = merger.set(lts)[i];
1583  Operation *loop =
1584  startLoop(merger, codegen, rewriter, op, at, li, needsUniv);
1585 
1586  // Visit all lattice points with Li >= Lj to generate the
1587  // loop-body, possibly with if statements for co-iteration.
1588  Value redInput = codegen.redVal;
1589  Value cntInput = codegen.expCount;
1590  Value insInput = codegen.insChain;
1591  bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1592  for (unsigned j = 0; j < lsize; j++) {
1593  unsigned lj = merger.set(lts)[j];
1594  unsigned ej = merger.lat(lj).exp;
1595  if (li == lj || merger.latGT(li, lj)) {
1596  // Recurse into body of each branch.
1597  if (isWhile) {
1598  scf::IfOp ifOp =
1599  genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1600  genStmt(merger, codegen, rewriter, op, ej, at + 1);
1601  endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput,
1602  insInput);
1603  } else {
1604  genStmt(merger, codegen, rewriter, op, ej, at + 1);
1605  }
1606  }
1607  }
1608 
1609  // End a loop.
1610  needsUniv =
1611  endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
1612  }
1613 
1614  // End a loop sequence.
1615  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
1616 }
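// For intuition (simplified sketch, not verbatim output): for a kernel
// x(i) = a(i) + b(i) with a compressed and b dense, the lattice for loop i
// yields a co-iteration point and a dense remainder, producing roughly
//   scf.while (co-iterate a against the universal index i) {
//     scf.if (a contributes at i) { x[i] = a[pa] + b[i] } else { x[i] = b[i] }
//   }
//   scf.for (remaining i) { x[i] = b[i] }
// i.e., a while-loop with if statements for the union, followed by a plain
// loop for the leftover dense-only lattice point.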
1617 
1618 /// Converts the result computed by the sparse kernel into the required form.
1619 static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
1620  linalg::GenericOp op) {
1621  OpOperand *lhs = op.getDpsInitOperand(0);
1622  Value tensor = lhs->get();
1623  Type resType = tensor.getType();
1624  if (getSparseTensorEncoding(resType)) {
1625  // The sparse tensor rematerializes from the original sparse tensor's
1626  // underlying sparse storage format. For an insertion chain, the
1627  // tensor materializes from the chain with 'hasInserts' enabled.
1628  bool hasInserts = codegen.sparseOut == lhs;
1629  if (hasInserts)
1630  tensor = codegen.insChain;
1631  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, tensor, hasInserts);
1632  } else {
1633  // To rematerialize a non-annotated tensor, simply load it
1634  // from the bufferized value.
1635  Value val = codegen.loopEmitter.getValBuffer().back(); // value array
1636  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
1637  }
1638 }
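// Roughly (illustrative IR, assuming a 1-d kernel and a #SparseVector
// encoding): an annotated output is rematerialized as
//   %r = sparse_tensor.load %t hasInserts : tensor<?xf64, #SparseVector>
// (with hasInserts only when an insertion chain was built), whereas a
// non-annotated output simply reloads the bufferized values:
//   %r = bufferization.to_tensor %buf : memref<?xf64>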
1639 
1640 //===----------------------------------------------------------------------===//
1641 // Sparse compiler rewriting methods.
1642 //===----------------------------------------------------------------------===//
1643 
1644 namespace {
1645 /// Sparse rewriting rule for a generic Linalg operation.
1646 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1647 public:
1648  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1649  : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1650 
1651  LogicalResult matchAndRewrite(linalg::GenericOp op,
1652  PatternRewriter &rewriter) const override {
1653  // Detects sparse annotations and translates the per-dimension sparsity
1654  // information for all tensors to loop indices in the kernel.
1655  if (op.getNumDpsInits() != 1)
1656  return failure();
1657  unsigned numTensors = op->getNumOperands();
1658  unsigned numLoops = op.getNumLoops();
1659  unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op);
1660  Merger merger(numTensors, numLoops, numFilterLoops);
1661  if (!findSparseAnnotations(merger, op))
1662  return failure();
1663 
1664  // Builds the tensor expression for the Linalg operation in SSA form.
1665  Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
1666  if (!optExp.has_value())
1667  return failure();
1668 
1669  unsigned exp = optExp.value();
1670  OpOperand *sparseOut = nullptr;
1671  unsigned outerParNest = 0;
1672  // Computes a topologically sorted iteration graph to ensure tensors
1673  // are visited in natural index order. Gradually relaxes the considered
1674  // constraints until an acyclic iteration graph results, such that sparse
1675  // code generation can proceed. As a last resort, an attempt is made
1676  // to resolve cycles by inserting a conversion.
1677  std::vector<unsigned> topSort;
1678  // Whether the current GenericOp is admissible.
1679  bool isAdmissible = false;
1680  bool hasCycle = true;
1681  // A const list of all masks used for iteration graph computation.
1682  // Must be ordered from strict to loose.
1683  const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
1684  SortMask::kIncludeDense, SortMask::kSparseOnly};
1685  for (auto mask : allMask)
1686  if (computeIterationGraph(merger, op, topSort, mask)) {
1687  hasCycle = false;
1688  if (isAdmissibleTensorExp(merger, op, topSort, exp, &sparseOut,
1689  outerParNest)) {
1690  isAdmissible = true;
1691  break;
1692  }
1693  // else try a set of less strict constraints.
1694  }
1695 
1696  if (hasCycle)
1697  // Give it one last shot to resolve the cycle.
1698  return resolveCycle(merger, rewriter, op);
1699  if (!isAdmissible)
1700  // Inadmissible expression, reject.
1701  return failure();
1702 
1703  merger.setHasSparseOut(sparseOut != nullptr);
1704 
1705  SmallVector<Value> tensors;
1706  for (OpOperand &t : op->getOpOperands())
1707  tensors.push_back(t.get());
1708 
1709  // Recursively generates code if admissible.
1710  CodeGen codegen(options, op.getContext(), tensors, numTensors, numLoops,
1711  sparseOut, outerParNest, topSort);
1712  genBuffers(merger, codegen, rewriter, op);
1713  genInitConstantDenseAddress(codegen, rewriter, op);
1714  genStmt(merger, codegen, rewriter, op, exp, 0);
1715  genResult(merger, codegen, rewriter, op);
1716  return success();
1717  }
1718 
1719 private:
1720  // Last resort cycle resolution.
1721  LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter,
1722  linalg::GenericOp op) const {
1723  // Compute topological sort while leaving out every
1724  // sparse input tensor in succession until an acyclic
1725  // iteration graph results.
1726  std::vector<unsigned> topSort;
1727  for (OpOperand *t : op.getDpsInputOperands()) {
1728  unsigned tensor = t->getOperandNumber();
1729  Value tval = t->get();
1730  auto srcEnc = getSparseTensorEncoding(tval.getType());
1731  if (!srcEnc ||
1732  !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t))
1733  continue;
1734  // Found an input tensor that resolves the cycle by inserting a
1735  // conversion into a sparse tensor that adheres to the iteration
1736  // graph order. Also releases the temporary sparse tensor.
1737  //
1738  // TODO: investigate fusing the conversion with computation,
1739  // especially if it is a direct yield!
1740  //
1741  auto srcTp = tval.getType().cast<RankedTensorType>();
1742  auto dstEnc = SparseTensorEncodingAttr::get(
1743  op->getContext(), srcEnc.getDimLevelType(),
1744  permute(merger, getContext(), op.getMatchingIndexingMap(t),
1745  topSort), // new order
1746  srcEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(),
1747  srcEnc.getIndexBitWidth());
1748  auto dstTp = RankedTensorType::get(srcTp.getShape(),
1749  srcTp.getElementType(), dstEnc);
1750  auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
1751  op->setOperand(tensor, convert);
1752  rewriter.setInsertionPointAfter(op);
1753  rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
1754  return success();
1755  }
1756  // Cannot be resolved with a single conversion.
1757  // TODO: convert more than one?
1758  return failure();
1759  }
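// For intuition (hypothetical kernel): in x(i,j) = a(i,j) + b(j,i) with both
// a and b compressed in their stored row order, the sparse constraints demand
// i before j (for a) and j before i (for b), so no acyclic iteration graph
// exists. Converting b into a temporary sparse tensor whose dimension ordering
// follows the chosen topological sort breaks the cycle; the temporary is
// deallocated right after the rewritten kernel.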
1760 
1761  /// Options to control sparse code generation.
1762  SparsificationOptions options;
1763 };
1764 
1765 } // namespace
1766 
1767 /// Populates the given patterns list with rewriting rules required for
1768 /// the sparsification of linear algebra operations.
1769 void mlir::populateSparsificationPatterns(
1770  RewritePatternSet &patterns, const SparsificationOptions &options) {
1771  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
1772 }
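// Typical driver-side usage (sketch, assuming this runs from inside a pass
// with the standard greedy rewriter):
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// which is roughly how the sparsification pass itself wires up these patterns.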