MLIR  19.0.0git
Merger.cpp
Go to the documentation of this file.
1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace sparse_tensor {
21 
/// Classifies a tensor expression kind by how many sub-expression operands
/// it carries.
enum class ExpArity {
  kNullary, // leaf kinds: no child expressions
  kUnary,   // a single child expression
  kBinary,  // up to two child expressions
};
27 
29  switch (k) {
30  // Leaf.
35  return ExpArity::kNullary;
71  return ExpArity::kUnary;
72  // Binary operations.
96  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
97  return ExpArity::kBinary;
98  }
99  llvm_unreachable("unexpected kind");
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // Constructors.
104 //===----------------------------------------------------------------------===//
105 
107  Operation *o, Attribute a)
108  : kind(k), val(v), op(o), attr(a) {
109  switch (kind) {
110  // Leaf.
112  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
113  tensor = x;
114  return;
116  assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
117  return;
119  assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
120  return;
122  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
123  loop = x;
124  return;
125  // Unary operations.
147  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
148  children.e0 = x;
149  children.e1 = y;
150  return;
162  assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
163  children.e0 = x;
164  children.e1 = y;
165  return;
168  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
169  children.e0 = x;
170  children.e1 = y;
171  return;
173  // No assertion on y can be made, as the branching paths involve both
174  // a unary (`mapSet`) and binary (`disjSet`) pathway.
175  assert(x != detail::kInvalidId && !v && o);
176  children.e0 = x;
177  children.e1 = y;
178  return;
179  // Binary operations.
199  assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
200  children.e0 = x;
201  children.e1 = y;
202  return;
205  assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
206  children.e0 = x;
207  children.e1 = y;
208  return;
211  assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
212  children.e0 = x;
213  children.e1 = y;
214  return;
216  assert(x != detail::kInvalidId && !v && o);
217  children.e0 = x;
218  children.e1 = y;
219  return;
220  }
221  llvm_unreachable("unexpected kind");
222 }
223 
224 Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
225  unsigned maxLvlRank)
226  : outTensor(numInputOutputTensors - 1),
227  syntheticTensor(numInputOutputTensors),
228  numTensors(numInputOutputTensors + 1), numLoops(numLoops),
229  hasSparseOut(false),
230  lvlTypes(numTensors,
231  std::vector<LevelType>(numLoops, LevelFormat::Undef)),
232  loopToLvl(numTensors,
233  std::vector<std::optional<Level>>(numLoops, std::nullopt)),
234  lvlToLoop(numTensors,
235  std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
236  loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
237  numTensors, std::nullopt)),
238  levelToDependentLoop(numTensors,
239  std::vector<std::vector<LoopCoeffPair>>(
240  maxLvlRank, std::vector<LoopCoeffPair>())),
241  loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
242 
243 //===----------------------------------------------------------------------===//
244 // Lattice methods.
245 //===----------------------------------------------------------------------===//
246 
248  assert(isValidTensorId(t));
249  const ExprId eNew(tensorExps.size());
250  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
251  Value(), nullptr, nullptr);
252  return eNew;
253 }
254 
256  assert(isValidLoopId(i));
257  const ExprId eNew(tensorExps.size());
258  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
259  Value(), nullptr, nullptr);
260  return eNew;
261 }
262 
264  const ExprId eNew(tensorExps.size());
265  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
266  detail::kInvalidId, v, nullptr, nullptr);
267  return eNew;
268 }
269 
271  const ExprId eNew(tensorExps.size());
272  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
273  detail::kInvalidId, Value(), nullptr, nullptr);
274  return eNew;
275 }
276 
278  Attribute attr) {
279  assert(k > TensorExp::Kind::kLoopVar);
280  const ExprId eNew(tensorExps.size());
281  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
282  return eNew;
283 }
284 
286  Attribute attr) {
287  assert(k > TensorExp::Kind::kLoopVar);
288  const ExprId eNew(tensorExps.size());
289  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
290  return eNew;
291 }
292 
294  const LatPointId pNew(latPoints.size());
295  const unsigned size = numLoops * numTensors;
296  const TensorLoopId b = makeTensorLoopId(t, i);
297  latPoints.emplace_back(size, e);
298  latPoints[pNew].bits.set(b);
299  return pNew;
300 }
301 
302 LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
303  assert(bits.size() == numLoops * numTensors);
304  const LatPointId pNew(latPoints.size());
305  latPoints.emplace_back(bits, e);
306  return pNew;
307 }
308 
310  const LatSetId sNew(latSets.size());
311  latSets.emplace_back();
312  return sNew;
313 }
314 
316  Operation *op) {
317  TensorExp::Kind kind = exp(e).kind;
318  Attribute attr = exp(e).attr;
319  const LatPointId pNew(latPoints.size());
320  const auto &point0 = lat(p0);
321  const auto &point1 = lat(p1);
322  BitVector bits(point0.bits);
323  bits |= point1.bits;
324  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
325  latPoints.emplace_back(bits, ne);
326  return pNew;
327 }
328 
330  const LatSetId sNew = addSet();
331  auto &setNew = latSets[sNew];
332  for (const LatPointId p0 : set(s0))
333  for (const LatPointId p1 : set(s1))
334  setNew.push_back(conjLat(e, p0, p1, op));
335  return sNew;
336 }
337 
339  const LatSetId sNew = conjSet(e, s0, s1, op);
340  TensorExp::Kind kind = exp(e).kind;
341  // Followed by all in s0.
342  latSets[sNew].append(latSets[s0]);
343  // Map binary 0-y to unary -y.
344  // TODO: move this if-else logic into buildLattices
345  if (kind == TensorExp::Kind::kSubF)
346  s1 = mapSet(TensorExp::Kind::kNegF, s1);
347  else if (kind == TensorExp::Kind::kSubC)
348  s1 = mapSet(TensorExp::Kind::kNegC, s1);
349  else if (kind == TensorExp::Kind::kSubI)
350  s1 = mapSet(TensorExp::Kind::kNegI, s1);
351  // Followed by all in s1.
352  latSets[sNew].append(latSets[s1]);
353  return sNew;
354 }
355 
357  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
358  exp(e).kind == TensorExp::Kind::kCmpF);
359  const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360 
361  ExprId e0 = exp(e).children.e0;
362  ExprId e1 = exp(e).children.e1;
363  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
364  exp(e1).kind == TensorExp::Kind::kSynZero) {
365  // lhs and rhs can't be synthetic zero at the same time.
366  assert(exp(e0).kind != exp(e1).kind);
367  // If one of the operands has already been assigned to zero (the
368  // element is absent in the corresponding operand), then we do not
369  // need to build disjunctive set for it.
370  return sNew;
371  }
372 
373  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
374  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
375  latSets[sNew].append(latSets[lhsSet]);
376  latSets[sNew].append(latSets[rhsSet]);
377  return sNew;
378 }
379 
381  bool includeLeft, TensorExp::Kind ltrans,
382  Operation *opleft, bool includeRight,
383  TensorExp::Kind rtrans, Operation *opright) {
384  Attribute a = exp(e).attr;
385  const LatSetId sNew = conjSet(e, s0, s1, orig);
386  // Left Region.
387  if (includeLeft) {
388  if (opleft)
389  s0 = mapSet(ltrans, s0, Value(), opleft, a);
390  latSets[sNew].append(latSets[s0]);
391  }
392  // Right Region.
393  if (includeRight) {
394  if (opright)
395  s1 = mapSet(rtrans, s1, Value(), opright, a);
396  latSets[sNew].append(latSets[s1]);
397  }
398  return sNew;
399 }
400 
402  Operation *op, Attribute a) {
403  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
404  TensorExp::Kind::kDenseOp == kind);
405  const LatSetId sNew = addSet();
406  auto &setNew = latSets[sNew];
407  for (const LatPointId p : set(s0)) {
408  const auto &point = latPoints[p];
409  setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
410  }
411  return sNew;
412 }
413 
415  TensorExp::Kind kind = exp(e).kind;
416  Attribute a = exp(e).attr;
417  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
418  // Must be a binary operation.
419  const LatSetId sNew = addSet();
420  auto &setNew = latSets[sNew];
421  const ExprId zeroExp = addSynZeroExp();
422  for (const LatPointId p : set(s0)) {
423  const auto &point = latPoints[p];
424  ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
425  : addExp(kind, point.exp, zeroExp, nullptr, a);
426  setNew.push_back(addLat(point.bits, newExp));
427  }
428  return sNew;
429 }
430 
432  const LatSetId sNew = addSet();
433  auto &setNew = latSets[sNew];
434  const auto &set0 = set(s0);
435  assert(!set0.empty());
436  const LatPointId p0 = set0[0];
437  for (const LatPointId p1 : set0) {
438  bool add = true;
439  if (p0 != p1) {
440  // Check whether this is a straightforward copy.
441  if (expIsTensor(latPoints[p1].exp, outTensor))
442  continue;
443  // Check whether this conjunction is already covered.
444  for (const LatPointId p2 : setNew) {
445  assert(!latGT(p1, p2)); // Lj => Li would be bad
446  if (onlyDenseDiff(p2, p1)) {
447  add = false;
448  break;
449  }
450  }
451  assert(!add || latGT(p0, p1));
452  }
453  if (add)
454  setNew.push_back(p1);
455  }
456  for (const LatPointId p : setNew)
457  latPoints[p].simple = simplifyCond(sNew, p);
458  return sNew;
459 }
460 
462  // First determine if this lattice point is a *singleton*, i.e.,
463  // the last point in a lattice, no other is less than this one.
464  bool isSingleton = true;
465  for (const LatPointId p1 : set(s0)) {
466  if (p0 != p1 && latGT(p0, p1)) {
467  isSingleton = false;
468  break;
469  }
470  }
471 
472  BitVector simple(latPoints[p0].bits);
473  bool reset = isSingleton && hasAnySparse(simple);
474  const TensorLoopId be = simple.size();
475  TensorLoopId offset = 0; // relative to the end
476  if (!reset)
477  // Starts resetting from a dense level, so that the first bit (if kept)
478  // is not undefined level-type.
479  for (unsigned b = 0; b < be; b++) {
480  if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
481  offset = be - b - 1; // relative to the end
482  break;
483  }
484  }
485 
486 // Now apply the two basic rules. We also iterate over the bits in reverse to
487 // always keep the rightmost bit (which could possibly be a synthetic tensor).
488  for (unsigned b = be - 1 - offset, i = 0; i < be;
489  b = b == 0 ? be - 1 : b - 1, i++) {
490  // Slice on dense level has `locate` property as well, and can be optimized.
491  if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
492  const auto lt = getLvlType(b);
493  if (!lt.hasSparseSemantic()) {
494  if (reset)
495  simple.reset(b);
496  reset = true;
497  }
498  }
499  }
500  return simple;
501 }
502 
504  const BitVector &bitsi = lat(i).bits;
505  const BitVector &bitsj = lat(j).bits;
506  assert(bitsi.size() == bitsj.size());
507  if (bitsi.count() > bitsj.count()) {
508  for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509  if (bitsj[b] && !bitsi[b])
510  return false;
511  return true;
512  }
513  return false;
514 }
515 
517  BitVector tmp(latPoints[j].bits);
518  tmp ^= latPoints[i].bits;
519  return !hasAnySparse(tmp);
520 }
521 
523  const auto &expr = exp(e);
524  // First we check `expIsTensor`.
525  if (expr.kind == TensorExp::Kind::kTensor)
526  return expr.tensor == t;
527 
528  switch (getExpArity(expr.kind)) {
529  case ExpArity::kNullary:
530  return false;
531  case ExpArity::kUnary: {
532  const ExprId e0 = expr.children.e0;
533  return expContainsTensor(e0, t);
534  }
535  case ExpArity::kBinary: {
536  const ExprId e0 = expr.children.e0;
537  const ExprId e1 = expr.children.e1;
538  return expContainsTensor(e0, t) || expContainsTensor(e1, t);
539  }
540  }
541  llvm_unreachable("unexpected arity");
542 }
543 
545  const auto &expr = exp(e);
546  switch (expr.kind) {
550  return expContainsTensor(expr.children.e0, outTensor);
554  return expContainsTensor(expr.children.e1, outTensor) ||
555  hasNegateOnOut(expr.children.e0);
557  bool lhsNeg = hasNegateOnOut(expr.children.e0);
558  if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
559  return hasNegateOnOut(expr.children.e1);
560  return lhsNeg;
561  }
562  default: {
563  switch (getExpArity(expr.kind)) {
564  case ExpArity::kNullary:
565  return false;
566  case ExpArity::kUnary:
567  return hasNegateOnOut(expr.children.e0);
568  case ExpArity::kBinary:
569  return hasNegateOnOut(expr.children.e0) ||
570  hasNegateOnOut(expr.children.e1);
571  }
572  }
573  }
574  llvm_unreachable("unexpected kind");
575 }
576 
578  assert(isValidTensorId(t));
579  const auto &expr = exp(e);
580  switch (expr.kind) {
581  // Leaf.
583  return expr.tensor == t;
587  return false;
588  // Unary operations.
622  return isSingleCondition(t, expr.children.e0);
625  return false;
626  // Binary operations.
627  case TensorExp::Kind::kDivF: // note: x / c only
631  assert(!maybeZero(expr.children.e1));
632  return isSingleCondition(t, expr.children.e0);
633  case TensorExp::Kind::kShrS: // note: x >> inv only
636  assert(isInvariant(expr.children.e1));
637  return isSingleCondition(t, expr.children.e0);
643  if (isSingleCondition(t, expr.children.e0))
644  return isSingleCondition(t, expr.children.e1) ||
645  isInvariant(expr.children.e1);
646  if (isSingleCondition(t, expr.children.e1))
647  return isInvariant(expr.children.e0);
648  return false;
652  return isSingleCondition(t, expr.children.e0) &&
653  isSingleCondition(t, expr.children.e1);
662  return false;
664  // Since Merger guarantees all the operands of the kDenseOp to be dense, the
665  // operation must be single-condition.
666  return true;
667  }
668  llvm_unreachable("unexpected kind");
669 }
670 
671 bool Merger::hasAnySparse(const BitVector &bits) const {
672  for (TensorLoopId b : bits.set_bits()) {
673  const auto lt = getLvlType(b);
674  if (lt.hasSparseSemantic())
675  return true;
676  }
677  return hasSparseIdxReduction(bits);
678 }
679 
680 bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
681  for (TensorLoopId b : bits.set_bits())
683  return true;
684  return false;
685 }
686 
687 #ifndef NDEBUG
688 
689 //===----------------------------------------------------------------------===//
690 // Print methods (for debugging).
691 //===----------------------------------------------------------------------===//
692 
693 static const char *kindToOpSymbol(TensorExp::Kind kind) {
694  switch (kind) {
695  // Leaf.
697  return "tensor";
699  return "invariant";
701  return "index";
703  return "0";
704  // Unary operations.
708  return "abs";
710  return "ceil";
712  return "floor";
715  return "sqrt";
718  return "expm1";
721  return "log1p";
723  return "relu";
726  return "sin";
729  return "tanh";
733  return "-";
745  return "complex.im";
747  return "complex.re";
749  return "cast";
751  return "binary_branch";
753  return "unary";
755  return "select";
756  // Binary operations.
760  return "*";
765  return "/";
769  return "+";
773  return "-";
775  return "&";
777  return "|";
779  return "^";
781  return "a>>";
783  return ">>";
785  return "<<";
788  return "cmp";
790  return "binary";
792  return "reduce";
794  return "dense";
795  }
796  llvm_unreachable("unexpected kind for symbol");
797 }
798 
799 void Merger::dumpExp(ExprId e) const {
800  const auto &expr = exp(e);
801  switch (expr.kind) {
802  // Leaf.
804  if (expr.tensor == syntheticTensor)
805  llvm::dbgs() << "synthetic_";
806  else if (expr.tensor == outTensor)
807  llvm::dbgs() << "output_";
808  llvm::dbgs() << "tensor_" << expr.tensor;
809  break;
811  llvm::dbgs() << "invariant";
812  break;
814  llvm::dbgs() << "0";
815  break;
817  llvm::dbgs() << "loopvar_" << expr.loop;
818  break;
819  // Unary operations.
855  llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
856  dumpExp(expr.children.e0);
857  break;
858  // Binary operations.
883  llvm::dbgs() << "(";
884  dumpExp(expr.children.e0);
885  llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
886  if (expr.attr)
887  llvm::dbgs() << "{" << expr.attr << "}";
888  if (expr.children.e1 != detail::kInvalidId) {
889  llvm::dbgs() << " ";
890  dumpExp(expr.children.e1);
891  llvm::dbgs() << ")";
892  } else {
893  assert(expr.kind == TensorExp::Kind::kDenseOp);
894  }
895  break;
896  }
897 }
898 
900  const auto &point = lat(p);
901  llvm::dbgs() << "lat(";
902  dumpBits(point.bits);
903  llvm::dbgs() << " :";
904  dumpBits(point.simple);
905  llvm::dbgs() << " : ";
906  dumpExp(point.exp);
907  llvm::dbgs() << " )\n";
908 }
909 
910 void Merger::dumpSet(LatSetId s) const {
911  const auto &ss = set(s);
912  llvm::dbgs() << "{ #" << ss.size() << "\n";
913  for (const LatPointId p : ss) {
914  llvm::dbgs() << " ";
915  dumpLat(p);
916  }
917  llvm::dbgs() << "}\n";
918 }
919 
920 void Merger::dumpBits(const BitVector &bits) const {
921  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
922  if (bits[b]) {
923  const TensorId t = tensor(b);
924  const LoopId i = loop(b);
925  const auto lt = lvlTypes[t][i];
927  llvm::dbgs() << " DEP_" << t << "_" << i;
928  else
929  llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
930  }
931  }
932 }
933 
934 #endif // NDEBUG
935 
936 //===----------------------------------------------------------------------===//
937 // Builder methods.
938 //===----------------------------------------------------------------------===//
939 
941  // NOTE: The `expr` reference will be invalidated by recursive calls
942  // (and any other method that may add new expressions); therefore, the
943  // code below must make sure to copy fields of `expr` into local variables
944  // before making any recursive calls.
945  const auto &expr = exp(e);
946  const TensorExp::Kind kind = expr.kind;
947  switch (kind) {
948  // Leaf.
953  // Either the loop-var is really used in the tensor expression, or it is
954  // set to the undefined loop-var in that level. An invariant expression,
955  // a proper index value, and a truly dynamic sparse output tensor are set
956  // to a synthetic tensor with undefined indices only to ensure the
957  // iteration space is not skipped as a result of their contents.
958  const LatSetId s = addSet();
959  TensorId t = syntheticTensor;
960  if (kind == TensorExp::Kind::kTensor) {
961  t = expr.tensor;
962  if (hasSparseOut && t == outTensor)
963  t = syntheticTensor;
964  }
965  latSets[s].push_back(addLat(t, i, e));
966  return s;
967  }
968  // Unary operations.
1001  // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
1002  // lattice set of the operand through the operator into a new set.
1003  //
1004  // -y|!y | y |
1005  // --+---+---+
1006  // | 0 |-y |
1007  {
1008  const ExprId e0 = expr.children.e0;
1009  const Value v = expr.val;
1010  Attribute a = expr.attr;
1011  return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
1012  }
1015  // The left or right half of a binary operation which has already
1016  // been split into separate operations for each region.
1017  {
1018  const ExprId e0 = expr.children.e0;
1019  Operation *const op = expr.op;
1020  return mapSet(kind, buildLattices(e0, i), Value(), op);
1021  }
1023  // A custom unary operation.
1024  //
1025  // op y| !y | y |
1026  // ----+----------+------------+
1027  // | absent() | present(y) |
1028  {
1029  const ExprId e0 = expr.children.e0;
1030  UnaryOp unop = cast<UnaryOp>(expr.op);
1031  const LatSetId child0 = buildLattices(e0, i);
1032  Region &absentRegion = unop.getAbsentRegion();
1033  if (absentRegion.empty()) {
1034  // Simple mapping over existing values.
1035  return mapSet(kind, child0, Value(), unop);
1036  }
1037  // Use a disjunction with `unop` on the left and the absent value as an
1038  // invariant on the right.
1039  Block &absentBlock = absentRegion.front();
1040  YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1041  const Value absentVal = absentYield.getSingleResult();
1042  const ExprId rhs = addInvariantExp(absentVal);
1043  return disjSet(e, child0, buildLattices(rhs, i), unop);
1044  }
1045  // Binary operations.
1050  // A multiplicative operation only needs to be performed
1051  // for the conjunction of sparse iteration spaces.
1052  //
1053  // x*y|!y | y |
1054  // ---+---+---+
1055  // !x | 0 | 0 |
1056  // x | 0 |x*y|
1057  //
1058  // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1059  {
1060  const ExprId e0 = expr.children.e0;
1061  const ExprId e1 = expr.children.e1;
1062  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1063  }
1068  // A division is tricky, since 0/0, 0/c, c/0 all have
1069  // specific outcomes for floating-point and integers.
1070  // Thus, we need to traverse the full iteration space.
1071  //
1072  // x/y|!y | y |
1073  // ---+---+---+
1074  // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1075  // x |x/0|x/y| INT: x/0=exception for any x
1076  //
1077  // TODO: for now we "fixed" this by only accepting x/c cases
1078  // during expression building, so that the conjunction
1079 // rule applies (viz. x/c = x*(1/c) as far as lattice
1080  // construction is concerned).
1081  {
1082  const ExprId e0 = expr.children.e0;
1083  const ExprId e1 = expr.children.e1;
1084  assert(!maybeZero(e1));
1085  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1086  }
1093  case TensorExp::Kind::kOrI:
1095  // An additive operation needs to be performed
1096  // for the disjunction of sparse iteration spaces.
1097  //
1098  // x+y|!y | y | x-y|!y | y |
1099  // ---+---+---+ ---+---+---+
1100  // !x | 0 | y | !x | 0 |-y |
1101  // x | x |x+y| x | x |x-y|
1102  {
1103  const ExprId e0 = expr.children.e0;
1104  const ExprId e1 = expr.children.e1;
1105  return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1106  }
1109  // A comparison operation needs to be performed
1110  // for the disjunction of sparse iteration spaces.
1111  //
1112  // x < y | !y | y |
1113  // -------+-------+-------+
1114  // !x | 0 | 0 < y |
1115  // x | x < 0 | x < y |
1116  {
1117  const ExprId e0 = expr.children.e0;
1118  const ExprId e1 = expr.children.e1;
1119  return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1120  }
1124  // A shift operation by an invariant amount (viz. tensor expressions
1125  // can only occur at the left-hand-side of the operator) can be handled
1126  // with the conjunction rule.
1127  {
1128  const ExprId e0 = expr.children.e0;
1129  const ExprId e1 = expr.children.e1;
1130  assert(isInvariant(e1));
1131  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1132  }
1134  // A custom binary operation.
1135  //
1136  // x op y| !y | y |
1137  // ------+---------+--------------+
1138  // !x | empty | right(y) |
1139  // x | left(x) | overlap(x,y) |
1140  {
1141  const ExprId e0 = expr.children.e0;
1142  const ExprId e1 = expr.children.e1;
1143  BinaryOp binop = cast<BinaryOp>(expr.op);
1144  const LatSetId child0 = buildLattices(e0, i);
1145  const LatSetId child1 = buildLattices(e1, i);
1146  Region &leftRegion = binop.getLeftRegion();
1147  Region &rightRegion = binop.getRightRegion();
1148  // Left Region.
1149  Operation *leftYield = nullptr;
1150  if (!leftRegion.empty()) {
1151  Block &leftBlock = leftRegion.front();
1152  leftYield = leftBlock.getTerminator();
1153  }
1154  // Right Region.
1155  Operation *rightYield = nullptr;
1156  if (!rightRegion.empty()) {
1157  Block &rightBlock = rightRegion.front();
1158  rightYield = rightBlock.getTerminator();
1159  }
1160  bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1161  bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1162  return combiSet(e, child0, child1, binop, includeLeft,
1163  TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
1164  TensorExp::Kind::kBinaryBranch, rightYield);
1165  }
1167  // A custom reduce operation.
1168  {
1169  const ExprId e0 = expr.children.e0;
1170  const ExprId e1 = expr.children.e1;
1171  Operation *const op = expr.op;
1172  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1173  }
1175  // It does not really matter whether we use conjunctive/disjunctive set
1176  // here, as all the operands of kDenseOp must be dense, the disjunctive set
1177  // will be optimized into conjunctive set eventually.
1178  if (expr.children.e1 == detail::kInvalidId) {
1179  const ExprId e0 = expr.children.e0;
1180  Operation *const op = expr.op;
1181  return mapSet(kind, buildLattices(e0, i), Value(), op);
1182  }
1183 
1184  const ExprId e0 = expr.children.e0;
1185  const ExprId e1 = expr.children.e1;
1186  Operation *const op = expr.op;
1187  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1188  }
1189  }
1190  llvm_unreachable("unexpected expression kind");
1191 }
1192 
1193 std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1194  // Build the linalg semantics backward from yield.
1195  Operation *yield = op.getRegion().front().getTerminator();
1196  assert(isa<linalg::YieldOp>(yield));
1197  return buildTensorExp(op, yield->getOperand(0)).first;
1198 }
1199 
1200 /// Only returns true if we are certain this is a zero.
1201 static bool isCertainZero(Value val) {
1202  if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
1203  ArrayAttr arrayAttr = c.getValue();
1204  return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1205  cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1206  }
1207  if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
1208  return c.value() == 0;
1209  if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
1210  return c.value().isZero();
1211  return false;
1212 }
1213 
1214 /// Only returns false if we are certain this is a nonzero.
1215 bool Merger::maybeZero(ExprId e) const {
1216  const auto &expr = exp(e);
1217  if (expr.kind == TensorExp::Kind::kInvariant) {
1218  // Note that this is different from isCertainZero() in a subtle
1219  // way by always returning true for non-constants.
1220  if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1221  ArrayAttr arrayAttr = c.getValue();
1222  return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1223  cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1224  }
1225  if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1226  return c.value() == 0;
1227  if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1228  return c.value().isZero();
1229  }
1230  return true;
1231 }
1232 
1233 Type Merger::inferType(ExprId e, Value src) const {
1234  // Obtain the destination type from the cast node.
1235  Type dtp = exp(e).val.getType();
1236  // Inspect source type. For vector types, apply the same
1237  // vectorization to the destination type.
1238  if (auto vtp = dyn_cast<VectorType>(src.getType()))
1239  return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1240  return dtp;
1241 }
1242 
1243 /// Ensures that the sparsifier can generate code for expression.
1244 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1245  // Arguments are always admissible.
1246  if (isa<BlockArgument>(v))
1247  return true;
1248  // Accept index anywhere.
1249  Operation *def = v.getDefiningOp();
1250  if (isa<linalg::IndexOp>(def))
1251  return true;
1252  // Operation defined outside branch.
1253  if (def->getBlock() != block)
1254  return def->getBlock() != op->getBlock(); // invariant?
1255  // Operation defined within branch. Anything is accepted,
1256  // as long as all subexpressions are admissible.
1257  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1258  if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1259  return false;
1260  return true;
1261 }
1262 
1263 /// Ensures that the sparsifier can generate code for branch.
1264 static bool isAdmissibleBranch(Operation *op, Region &region) {
1265  if (region.empty())
1266  return true;
1267  // Build the semi-ring branch semantics backward from yield.
1268  Operation *yield = region.front().getTerminator();
1269  assert(isa<YieldOp>(yield));
1270  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1271 }
1272 
1273 // Recognizes a direct GT comparison.
1274 static bool isGreater(TensorExp::Kind kind, Attribute attr) {
1275  if (kind == TensorExp::Kind::kCmpI) {
1276  auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
1277  return pred == arith::CmpIPredicate::ugt ||
1278  pred == arith::CmpIPredicate::sgt;
1279  }
1280  if (kind == TensorExp::Kind::kCmpF) {
1281  auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
1282  return pred == arith::CmpFPredicate::UGT ||
1283  pred == arith::CmpFPredicate::OGT;
1284  }
1285  return false;
1286 }
1287 
1288 std::pair<std::optional<ExprId>, bool>
1289 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1290  // Recursion leaves.
1291  if (auto arg = dyn_cast<BlockArgument>(v)) {
1292  const TensorId tid = makeTensorId(arg.getArgNumber());
1293  // Any argument of the generic op that is not marked as a scalar
1294  // argument is considered a tensor, indexed by the implicit loop
1295  // bounds. This includes rank-0 tensor arguments.
1296  if (arg.getOwner()->getParentOp() == op) {
1297  OpOperand &t = op->getOpOperand(tid);
1298  bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1299  if (!op.isScalar(&t))
1300  return {addTensorExp(tid), hasSpDep};
1301  v = t.get(); // get scalar value
1302  }
1303  // Any other argument (marked as scalar argument for the generic op
1304  // or belonging to an enveloping op) is considered invariant.
1305  return {addInvariantExp(v), /*hasSpDep=*/false};
1306  }
1307 
1308  // Something defined outside is invariant.
1309  Operation *def = v.getDefiningOp();
1310  if (def->getBlock() != &op.getRegion().front())
1311  return {addInvariantExp(v), /*hasSpDep=*/false};
1312  // Construct index operations.
1313  if (def->getNumOperands() == 0) {
1314  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1315  return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1316  }
1317 
1318  // Construct unary operations if subexpression can be built.
1319  if (def->getNumOperands() == 1) {
1320  const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1321  if (x.has_value()) {
1322  const ExprId e = *x;
1323  if (isa<math::AbsFOp>(def))
1324  return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1325  if (isa<complex::AbsOp>(def))
1326  return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
1327  if (isa<math::AbsIOp>(def))
1328  return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1329  if (isa<math::CeilOp>(def))
1330  return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1331  if (isa<math::FloorOp>(def))
1332  return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1333  if (isa<math::SqrtOp>(def))
1334  return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1335  if (isa<complex::SqrtOp>(def))
1336  return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1337  if (isa<math::ExpM1Op>(def))
1338  return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1339  if (isa<complex::Expm1Op>(def))
1340  return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1341  if (isa<math::Log1pOp>(def))
1342  return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1343  if (isa<complex::Log1pOp>(def))
1344  return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1345  if (isa<math::SinOp>(def))
1346  return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1347  if (isa<complex::SinOp>(def))
1348  return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1349  if (isa<math::TanhOp>(def))
1350  return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1351  if (isa<complex::TanhOp>(def))
1352  return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1353  if (isa<arith::NegFOp>(def))
1354  return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1355  if (isa<complex::NegOp>(def))
1356  return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1357  if (isa<arith::TruncFOp>(def))
1358  return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1359  if (isa<arith::ExtFOp>(def))
1360  return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1361  if (isa<arith::FPToSIOp>(def))
1362  return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1363  if (isa<arith::FPToUIOp>(def))
1364  return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1365  if (isa<arith::SIToFPOp>(def))
1366  return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1367  if (isa<arith::UIToFPOp>(def))
1368  return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1369  if (isa<arith::ExtSIOp>(def))
1370  return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1371  if (isa<arith::ExtUIOp>(def))
1372  return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
1373  if (isa<arith::IndexCastOp>(def))
1374  return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1375  if (isa<arith::TruncIOp>(def))
1376  return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
1377  if (isa<complex::ImOp>(def))
1378  return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
1379  if (isa<complex::ReOp>(def))
1380  return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1381  if (isa<arith::BitcastOp>(def))
1382  return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
1383  if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1384  if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1385  isAdmissibleBranch(unop, unop.getAbsentRegion()))
1386  return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1387  }
1388  if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1389  if (isAdmissibleBranch(selop, selop.getRegion()))
1390  return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
1391  }
1392  }
1393  }
1394 
1395  // Construct binary operations if subexpressions can be built.
1396  // See buildLattices() for an explanation of rejecting certain
1397  // division and shift operations.
1398  if (def->getNumOperands() == 2) {
1399  const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
1400  const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
1401  // For a conjunctive operation, it yields a "sparse" result if any operand
1402  // is sparse. For a disjunctive operation, it yields a "sparse" result if
1403  // all operands are sparse.
1404  bool conjSpVals = xSpVals || ySpVals;
1405  bool disjSpVals = xSpVals && ySpVals;
1406  if (x.has_value() && y.has_value()) {
1407  const ExprId e0 = *x;
1408  const ExprId e1 = *y;
1409  if (isa<arith::MulFOp>(def))
1410  return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
1411  if (isa<complex::MulOp>(def))
1412  return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
1413  if (isa<arith::MulIOp>(def))
1414  return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
1415  if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1416  return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
1417  if (isa<complex::DivOp>(def) && !maybeZero(e1))
1418  return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
1419  if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1420  return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
1421  if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1422  return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
1423  if (isa<arith::AddFOp>(def))
1424  return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
1425  if (isa<complex::AddOp>(def))
1426  return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
1427  if (isa<arith::AddIOp>(def))
1428  return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
1429  if (isa<arith::SubFOp>(def))
1430  return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
1431  if (isa<complex::SubOp>(def))
1432  return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
1433  if (isa<arith::SubIOp>(def))
1434  return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
1435  if (isa<arith::AndIOp>(def))
1436  return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
1437  if (isa<arith::OrIOp>(def))
1438  return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
1439  if (isa<arith::XOrIOp>(def))
1440  return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
1441  if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1442  return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
1443  if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1444  return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
1445  if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1446  return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
1447  if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1448  if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1449  ci.getPredicate() == arith::CmpIPredicate::sle &&
1450  ci.getPredicate() == arith::CmpIPredicate::sge &&
1451  ci.getPredicate() == arith::CmpIPredicate::ule &&
1452  ci.getPredicate() == arith::CmpIPredicate::uge) {
1453  // We can not sparsify comparison with equal, this is because 0 <= 0
1454  // yields true, and thus densifies the result.
1455  return {std::nullopt, false};
1456  }
1457 
1458  auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1459  ci.getPredicateAttr());
1460  return {e, conjSpVals};
1461  }
1462  if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1463  if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1464  cf.getPredicate() == arith::CmpFPredicate::OGE &&
1465  cf.getPredicate() == arith::CmpFPredicate::OLE &&
1466  cf.getPredicate() == arith::CmpFPredicate::ONE &&
1467  cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1468  cf.getPredicate() == arith::CmpFPredicate::UGE &&
1469  cf.getPredicate() == arith::CmpFPredicate::ULE &&
1470  cf.getPredicate() == arith::CmpFPredicate::ORD &&
1471  cf.getPredicate() == arith::CmpFPredicate::UNO) {
1472  // We can not sparsify comparison with equal, this is because 0 <= 0
1473  // yields true, and thus densifies the result.
1474  return {std::nullopt, false};
1475  }
1476  auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1477  cf.getPredicateAttr());
1478  return {e, conjSpVals};
1479  }
1480  if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1481  if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1482  (binop.getLeftIdentity() ||
1483  isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1484  (binop.getRightIdentity() ||
1485  isAdmissibleBranch(binop, binop.getRightRegion())))
1486  return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
1487  }
1488  }
1489  }
1490 
1491  // Construct ternary operations if subexpressions can be built.
1492  if (def->getNumOperands() == 3) {
1493  const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1494  const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1495  const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1496  bool hasSpDep = xDepSp || yDepSp || zDepSp;
1497  if (x.has_value() && y.has_value() && z.has_value()) {
1498  const ExprId e0 = *x;
1499  const ExprId e1 = *y;
1500  if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1501  if (isAdmissibleBranch(redop, redop.getRegion()))
1502  return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1503  }
1504  if (auto selop = dyn_cast<arith::SelectOp>(def)) {
1505  // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
1506  // operation inside a very specific ternary select operation.
1507  // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
1508  const auto &cnd = exp(*x);
1509  if (isGreater(cnd.kind, cnd.attr) &&
1512  isCertainZero(exp(*z).val)) {
1513  const auto &a = exp(cnd.children.e0);
1514  const auto &b = exp(cnd.children.e1);
1515  if (a.kind == TensorExp::Kind::kTensor &&
1516  a.tensor == exp(*y).tensor &&
1517  b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
1519  nullptr, cnd.attr),
1520  yDepSp};
1521  }
1522  }
1523  }
1524  }
1525  }
1526 
1527  // If we reach here, we are dealing with an operation that is not currently
1528  // sparsifiable. We can still generate code for it if all its operands only
1529  // have dense dependencies (i.e., all the values are loaded from dense
1530  // tensors).
1531  if (def->getNumResults() != 1) // only handle single result operation.
1532  return {std::nullopt, false};
1533  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1534  // Builds all the sub-expressions
1535  for (Value operand : def->getOperands())
1536  subExp.push_back(buildTensorExp(op, operand));
1537 
1538  if (llvm::all_of(subExp,
1539  [](auto e) { return e.first.has_value() && !e.second; })) {
1540  // All the subexpressions can be built and has *no* sparse dependencies.
1541  if (subExp.size() == 2) {
1542  auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1543  *subExp[1].first, def);
1544  return {e, false};
1545  }
1546  if (subExp.size() == 1) {
1547  auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1548  detail::kInvalidId, def);
1549  return {e, false};
1550  }
1551  }
1552 
1553  // Cannot build.
1554  return {std::nullopt, false};
1555 }
1556 
1557 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1558  ValueRange vals) {
1559  // Make a clone of overlap region.
1560  Region tmpRegion;
1561  IRMapping mapper;
1562  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1563  Block &clonedBlock = tmpRegion.front();
1564  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1565  // Merge cloned block and return yield value.
1566  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1567  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1568  Value val = clonedYield.getSingleResult();
1569  rewriter.eraseOp(clonedYield);
1570  rewriter.eraseOp(placeholder);
1571  return val;
1572 }
1573 
1575  Operation *op, Value v0) {
1576  if (!v0)
1577  // Empty input value must be propagated.
1578  return Value();
1579  UnaryOp unop = cast<UnaryOp>(op);
1580  Region &presentRegion = unop.getPresentRegion();
1581  if (presentRegion.empty())
1582  // Uninitialized Value() will be interpreted as missing data in the
1583  // output.
1584  return Value();
1585  return insertYieldOp(rewriter, loc, presentRegion, {v0});
1586 }
1587 
1589  Operation *op, Value v0, Value v1) {
1590  if (!v0 || !v1)
1591  // Empty input values must be propagated.
1592  return Value();
1593  BinaryOp binop = cast<BinaryOp>(op);
1594  Region &overlapRegion = binop.getOverlapRegion();
1595  if (overlapRegion.empty())
1596  // Uninitialized Value() will be interpreted as missing data in the
1597  // output.
1598  return Value();
1599  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1600 }
1601 
1602 static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
1603  Attribute attr) {
1604  Type tp = v0.getType();
1605  auto zero =
1606  rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
1607  Value cmp;
1608  if (isa<FloatType>(tp)) {
1609  auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
1610  cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
1611  } else {
1612  auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
1613  cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
1614  }
1615  return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
1616 }
1617 
1619  Value v1) const {
 // Rebuilds the SSA operation for tensor expression `e` and returns its
 // value. `v0` is the already-built first operand and `v1` the second one
 // (only used by the two-operand kinds). The switch dispatches on the kind
 // of the expression; reaching a leaf kind here is treated as unreachable.
 // NOTE(review): in this listing the first line of the signature and many
 // `case` labels are not present. Each unlabeled `return` below presumably
 // belongs to the kind matching the op it creates -- confirm against the
 // original file before relying on it.
1620  const auto &expr = exp(e);
1621  switch (expr.kind) {
1622  // Leaf.
1627  llvm_unreachable("unexpected non-op");
1628  // Unary operations.
1630  return rewriter.create<math::AbsFOp>(loc, v0);
1631  case TensorExp::Kind::kAbsC: {
 // The result type of complex abs is the complex type's float element type.
1632  auto type = cast<ComplexType>(v0.getType());
1633  auto eltType = cast<FloatType>(type.getElementType());
1634  return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1635  }
1637  return rewriter.create<math::AbsIOp>(loc, v0);
1639  return rewriter.create<math::CeilOp>(loc, v0);
1641  return rewriter.create<math::FloorOp>(loc, v0);
1643  return rewriter.create<math::SqrtOp>(loc, v0);
1645  return rewriter.create<complex::SqrtOp>(loc, v0);
1647  return rewriter.create<math::ExpM1Op>(loc, v0);
1649  return rewriter.create<complex::Expm1Op>(loc, v0);
1651  return rewriter.create<math::Log1pOp>(loc, v0);
1653  return rewriter.create<complex::Log1pOp>(loc, v0);
 // The comparison predicate for the select-against-zero is expr.attr.
1655  return buildRelu(rewriter, loc, v0, expr.attr);
1657  return rewriter.create<math::SinOp>(loc, v0);
1659  return rewriter.create<complex::SinOp>(loc, v0);
1661  return rewriter.create<math::TanhOp>(loc, v0);
1663  return rewriter.create<complex::TanhOp>(loc, v0);
1665  return rewriter.create<arith::NegFOp>(loc, v0);
1667  return rewriter.create<complex::NegOp>(loc, v0);
 // Integer negation is emitted as (0 - v0), with the zero constant typed
 // like v0.
1668  case TensorExp::Kind::kNegI: // no negi in std
1669  return rewriter.create<arith::SubIOp>(
1670  loc,
1671  rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1672  rewriter.getZeroAttr(v0.getType())),
1673  v0);
 // Cast-like ops: the result type is obtained from inferType(e, v0).
1675  return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1677  return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1679  return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1681  return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1683  return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1685  return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1687  return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1689  return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1691  return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1693  return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1694  case TensorExp::Kind::kCIm: {
 // Imaginary part: result type is the complex type's float element type.
1695  auto type = cast<ComplexType>(v0.getType());
1696  auto eltType = cast<FloatType>(type.getElementType());
1697  return rewriter.create<complex::ImOp>(loc, eltType, v0);
1698  }
1699  case TensorExp::Kind::kCRe: {
 // Real part: result type is the complex type's float element type.
1700  auto type = cast<ComplexType>(v0.getType());
1701  auto eltType = cast<FloatType>(type.getElementType());
1702  return rewriter.create<complex::ReOp>(loc, eltType, v0);
1703  }
1705  return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1706  // Binary operations.
1708  return rewriter.create<arith::MulFOp>(loc, v0, v1);
1710  return rewriter.create<complex::MulOp>(loc, v0, v1);
1712  return rewriter.create<arith::MulIOp>(loc, v0, v1);
1714  return rewriter.create<arith::DivFOp>(loc, v0, v1);
1716  return rewriter.create<complex::DivOp>(loc, v0, v1);
1718  return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1720  return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1722  return rewriter.create<arith::AddFOp>(loc, v0, v1);
1724  return rewriter.create<complex::AddOp>(loc, v0, v1);
1726  return rewriter.create<arith::AddIOp>(loc, v0, v1);
1728  return rewriter.create<arith::SubFOp>(loc, v0, v1);
1730  return rewriter.create<complex::SubOp>(loc, v0, v1);
1732  return rewriter.create<arith::SubIOp>(loc, v0, v1);
1734  return rewriter.create<arith::AndIOp>(loc, v0, v1);
1735  case TensorExp::Kind::kOrI:
1736  return rewriter.create<arith::OrIOp>(loc, v0, v1);
1738  return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1740  return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1742  return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1744  return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1745  case TensorExp::Kind::kCmpI: {
 // The integer comparison predicate is read from the expression attribute.
1746  auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1747  return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1748  }
1749  case TensorExp::Kind::kCmpF: {
 // The float comparison predicate is read from the expression attribute.
1750  auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1751  return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1752  }
 // Inlines a clone of the region that contains expr.op's block, passing
 // only v0 as the argument value.
1753  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1754  return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1755  {v0});
1757  return buildUnaryPresent(rewriter, loc, expr.op, v0);
1759  return insertYieldOp(rewriter, loc,
1760  cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
1761  {v0});
1763  return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1764  case TensorExp::Kind::kReduce: {
 // Inlines a clone of the reduce op's region with (v0, v1) as arguments.
1765  ReduceOp redOp = cast<ReduceOp>(expr.op);
1766  return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1767  }
 // Clones the op recorded in the expression with its first operand remapped
 // to v0 and, when it has exactly two operands, its second remapped to v1;
 // the clone's result 0 is returned.
 // NOTE(review): the `case` label opening this scope is not present in this
 // listing (presumably the dense-op kind) -- confirm.
1769  Operation *actualOp = expr.op;
1770  IRMapping mapping;
1771  mapping.map(actualOp->getOperand(0), v0);
1772  if (actualOp->getNumOperands() == 2)
1773  mapping.map(actualOp->getOperand(1), v1);
1774  return rewriter.clone(*actualOp, mapping)->getResult(0);
1775  }
1776  }
 // Reached only if no case above returned.
1777  llvm_unreachable("unexpected expression kind in build");
1778 }
1779 
1780 } // namespace sparse_tensor
1781 } // namespace mlir
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:450
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:75
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:53
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
Definition: Merger.cpp:315
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
Definition: Merger.cpp:338
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
Definition: Merger.cpp:577
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
Definition: Merger.cpp:680
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
Definition: Merger.cpp:522
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operand set to zero,...
Definition: Merger.cpp:414
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level contains non-trivial index expressio...
Definition: Merger.h:510
void dumpBits(const BitVector &bits) const
Definition: Merger.cpp:920
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
Definition: Merger.cpp:309
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
Definition: Merger.cpp:461
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
Definition: Merger.cpp:544
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
Definition: Merger.h:501
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets and also set one of the operand to zero: (s0 /\_op s1 (e0 op e1...
Definition: Merger.cpp:356
void dumpSet(LatSetId s) const
Definition: Merger.cpp:910
void dumpLat(LatPointId p) const
Definition: Merger.cpp:899
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
Definition: Merger.cpp:380
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
Definition: Merger.cpp:247
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:940
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
Definition: Merger.cpp:329
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
Definition: Merger.cpp:277
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
Definition: Merger.cpp:270
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition: Merger.h:249
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
Definition: Merger.cpp:1193
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
Definition: Merger.cpp:401
ArrayRef< LatPointId > set(LatSetId s) const
Definition: Merger.h:549
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:431
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
Definition: Merger.h:346
void dumpExp(ExprId e) const
Print methods (for debugging).
Definition: Merger.cpp:799
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
Definition: Merger.cpp:224
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
Definition: Merger.cpp:671
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
Definition: Merger.cpp:293
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
Definition: Merger.cpp:255
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
Definition: Merger.cpp:503
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
Definition: Merger.h:541
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
Definition: Merger.h:348
const LatPoint & lat(LatPointId p) const
Definition: Merger.h:545
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
Definition: Merger.cpp:516
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
Definition: Merger.cpp:263
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
Definition: Merger.h:243
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Definition: Merger.h:399
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:1618
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
Definition: Merger.h:255
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
Definition: Merger.h:370
@ Type
An inlay hint that is for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition: Merger.h:30
LevelFormat
This enum defines all supported storage formats without the level properties.
Definition: Enums.h:154
static bool isCertainZero(Value val)
Only returns true if we are certain this is a zero.
Definition: Merger.cpp:1201
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v)
Ensures that the sparsifier can generate code for expression.
Definition: Merger.cpp:1244
unsigned LatSetId
LatSet identifiers.
Definition: Merger.h:57
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region, ValueRange vals)
Definition: Merger.cpp:1557
std::string toMLIRString(LevelType lt)
Definition: Enums.h:447
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
Definition: Merger.h:60
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition: Merger.h:44
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0)
Definition: Merger.cpp:1574
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1)
Definition: Merger.cpp:1588
static bool isAdmissibleBranch(Operation *op, Region &region)
Ensures that the sparsifier can generate code for branch.
Definition: Merger.cpp:1264
unsigned LoopId
Loop identifiers.
Definition: Merger.h:38
static const char * kindToOpSymbol(TensorExp::Kind kind)
Definition: Merger.cpp:693
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
static bool isGreater(TensorExp::Kind kind, Attribute attr)
Definition: Merger.cpp:1274
unsigned ExprId
TensorExp identifiers.
Definition: Merger.h:48
static ExpArity getExpArity(TensorExp::Kind k)
Definition: Merger.cpp:28
static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0, Attribute attr)
Definition: Merger.cpp:1602
unsigned LatPointId
LatPoint identifiers.
Definition: Merger.h:52
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
Definition: Merger.h:64
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition: Merger.h:35
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
Definition: Merger.h:210
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
LoopId loop
kLoopVar expressions simply have a loop identifier.
Definition: Merger.h:96
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition: Merger.h:105
Kind
Tensor expression kind.
Definition: Merger.h:129
Children children
All other expressions hold the ExprIds of their children.
Definition: Merger.h:99
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
Definition: Merger.h:118
TensorId tensor
kTensor expressions simply have a tensor identifier.
Definition: Merger.h:93
Kind kind
Tensor expression kind.
Definition: Merger.h:89
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Definition: Merger.cpp:106
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.