//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return ExpArity::kNullary;
  // Unary operations.
  case TensorExp::Kind::kAbsF: case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI: case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF: case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC: case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C: case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC: case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC: case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC: case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC: case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF: case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS: case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF: case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS: case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast: case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary: case TensorExp::Kind::kSelect:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kMulF: case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI: case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC: case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU: case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC: case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF: case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI: case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI: case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI: case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI: case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}
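
// For example, getExpArity(TensorExp::Kind::kTensor) is ExpArity::kNullary,
// getExpArity(TensorExp::Kind::kNegF) is ExpArity::kUnary, and
// getExpArity(TensorExp::Kind::kMulF) is ExpArity::kBinary. Note that
// kDenseOp is classified as binary even though its second child may be
// detail::kInvalidId; callers must check children.e1 before using it.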

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o,
                     Attribute a)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kSynZero:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF: case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI: case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF: case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC: case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C: case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC: case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC: case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC: case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC: case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF: case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS: case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF: case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS: case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
  case TensorExp::Kind::kMulF: case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI: case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC: case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU: case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC: case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF: case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI: case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI: case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    attr = a;
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kDenseOp:
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}
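
// Illustrative encoding (a sketch, not lifted from a real kernel): for the
// expression a(i) + b(i), the merger stores three nodes,
//   e0 = TensorExp(kTensor, /*x=*/0, kInvalidId, Value(), nullptr, nullptr)
//   e1 = TensorExp(kTensor, /*x=*/1, kInvalidId, Value(), nullptr, nullptr)
//   e2 = TensorExp(kAddF, /*x=*/e0, /*y=*/e1, Value(), nullptr, nullptr)
// i.e. leaves carry a tensor (or loop) identifier while interior nodes
// carry the expression identifiers of their children.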

Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
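
// Illustrative construction (hypothetical sizes): a kernel with two input
// tensors, one output tensor, and two loops would be set up as
//   Merger merger(/*numInputOutputTensors=*/3, /*numLoops=*/2,
//                 /*maxLvlRank=*/2);
// which makes tensor id 2 the output, reserves id 3 for the synthetic
// tensor, and hence tracks numTensors == 4 overall.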

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
  return eNew;
}

ExprId Merger::addSynZeroExp() {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
                           Operation *op) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute attr = exp(e).attr;
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
  latPoints.emplace_back(bits, ne);
  return pNew;
}

LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(e, p0, p1, op));
  return sNew;
}

LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = conjSet(e, s0, s1, op);
  TensorExp::Kind kind = exp(e).kind;

  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}
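
// To contrast the two merges (sketch): for e = x(i) * y(i), conjSet yields
// only the conjunction { x /\ y : x*y }, whereas for e = x(i) + y(i),
// disjSet yields { x /\ y : x+y, x : x, y : y }, since a sum must still be
// produced where only one operand contributes a nonzero value.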

LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
         exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);

  ExprId e0 = exp(e).children.e0;
  ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // lhs and rhs can't be synthetic zero at the same time.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands has already been assigned to zero (the
    // element is absent in the corresponding operand), then we do not
    // need to build the disjunctive set for it.
    return sNew;
  }

  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
  return sNew;
}

LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                          bool includeLeft, TensorExp::Kind ltrans,
                          Operation *opleft, bool includeRight,
                          TensorExp::Kind rtrans, Operation *opright) {
  const LatSetId sNew = conjSet(e, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op) {
  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
         TensorExp::Kind::kDenseOp == kind);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
  }
  return sNew;
}

LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute a = exp(e).attr;
  // Must be a binary operation.
  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const ExprId zeroExp = addSynZeroExp();
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
  }
  return sNew;
}

LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}
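
// For example (sketch): for x(i) + y(i) with x sparse and y dense, the full
// lattice is { x /\ y : x+y, x : x, y : y }. The point "x : x" differs from
// "x /\ y : x+y" only in the dense loop contributed by y, so it is removed;
// iterating "x /\ y" followed by "y" already covers all nonzero cases.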

BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // is not of undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate the bits reversely to
  // always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
          !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}

bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}
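
// For example, a lattice point with bits { t0_i, t1_i } is strictly greater
// than a point with bits { t0_i }: its conditions form a proper superset.
// Points with equal or incomparable bit sets are not ordered by latGT.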

bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}

bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  case TensorExp::Kind::kDenseOp: {
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
      return hasNegateOnOut(expr.children.e1);
    return lhsNeg;
  }
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF: case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI: case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF: case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC: case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C: case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC: case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC: case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC: case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC: case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF: case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS: case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF: case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS: case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast: case TensorExp::Kind::kBinaryBranch:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF: case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI: case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF: case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF: case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI: case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI: case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI: case TensorExp::Kind::kBinary:
    return false;
  case TensorExp::Kind::kDenseOp:
    // Since Merger guarantees all the operands of the kDenseOp to be dense,
    // the operation must be single-condition.
    return true;
  }
  llvm_unreachable("unexpected kind");
}
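
// For example, with t referring to x: the expression x(i) * c for an
// invariant c is admitted as a single condition (only "x present" matters),
// whereas x(i) + y(i) is not, since the sum must also be produced where
// only y is present.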

bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto lt = getLvlType(b);
    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
        is2OutOf4LT(lt))
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kTruncF: case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS: case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF: case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS: case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  case TensorExp::Kind::kDenseOp:
    return "dense";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF: case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI: case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF: case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC: case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C: case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC: case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC: case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC: case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC: case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF: case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS: case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF: case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS: case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast: case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary: case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF: case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI: case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC: case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU: case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC: case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF: case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI: case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI: case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI: case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI: case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce: case TensorExp::Kind::kDenseOp:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
      llvm::dbgs() << ")";
    } else {
      assert(expr.kind == TensorExp::Kind::kDenseOp);
    }
    break;
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}
905 
906 void Merger::dumpSet(LatSetId s) const {
907  const auto &ss = set(s);
908  llvm::dbgs() << "{ #" << ss.size() << "\n";
909  for (const LatPointId p : ss) {
910  llvm::dbgs() << " ";
911  dumpLat(p);
912  }
913  llvm::dbgs() << "}\n";
914 }

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto lt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF: case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI: case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF: case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC: case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C: case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC: case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC: case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC: case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC: case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF: case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS: case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF: case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS: case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|   !y     |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF: case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI: case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC: case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //   x < y |  !y   |   y   |
    //  -------+-------+-------+
    //     !x  |   0   | 0 < y |
    //      x  | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use a conjunctive or disjunctive
    // set here: since all the operands of kDenseOp must be dense, the
    // disjunctive set will be optimized into a conjunctive set eventually.
    if (expr.children.e1 == detail::kInvalidId) {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
  }
  }
  llvm_unreachable("unexpected expression kind");
}
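
// Worked example (sketch): for e = x(i) + y(i) with both operands sparse in
// loop i, the disjunction rule above yields the three lattice points
//   { x, y } : x + y    (both operands present)
//   { x }    : x        (only x present)
//   { y }    : y        (only y present)
// and for e = x(i) - y(i) the last point instead carries the negation -y,
// which is exactly what the kSubF/kSubC/kSubI mapping in disjSet arranges.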

std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.getRegion().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0)).first;
}
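
// For instance (hypothetical generic op body):
//   %0 = arith.mulf %a, %b : f64
//   linalg.yield %0 : f64
// is traversed backward from the yield and, when %a and %b are block
// arguments of the body, becomes a kMulF expression over two kTensor leaves.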

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
  const auto &expr = exp(e);
  if (expr.kind == TensorExp::Kind::kInvariant) {
    if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
             cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
    }
    if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}
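
// Examples: a constant 1.0 divisor makes maybeZero return false, enabling
// the conjunctive treatment of division in buildLattices; a constant 0.0,
// a complex constant with zero components, or any non-constant value is
// conservatively reported as possibly zero.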

Type Merger::inferType(ExprId e, Value src) const {
  // Obtain the destination type from the cast node.
  Type dtp = exp(e).val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = dyn_cast<VectorType>(src.getType()))
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
  return dtp;
}
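
// For example (sketch): if the cast node records destination type f32 and
// the source value was vectorized to vector<4xf64>, the inferred type is
// vector<4xf32>, carrying over the vector shape and scalable dims.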

/// Ensures that the sparsifier can generate code for the expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissible.
  if (isa<BlockArgument>(v))
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block)
    return def->getBlock() != op->getBlock(); // invariant?
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissible.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that the sparsifier can generate code for the branch.
static bool isAdmissibleBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
}

std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  // Recursion leaves.
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand &t = op->getOpOperand(tid);
      bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
      if (!op.isScalar(&t))
        return {addTensorExp(tid), hasSpDep};
      v = t.get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return {addInvariantExp(v), /*hasSpDep=*/false};
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return {addInvariantExp(v), /*hasSpDep=*/false};
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
  }

  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
      if (isa<complex::AbsOp>(def))
        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
      if (isa<math::AbsIOp>(def))
        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
      if (isa<math::CeilOp>(def))
        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
      if (isa<math::FloorOp>(def))
        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
      if (isa<math::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
      if (isa<complex::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
      if (isa<math::ExpM1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
      if (isa<complex::Expm1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
      if (isa<math::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
      if (isa<complex::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
      if (isa<math::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
      if (isa<complex::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
      if (isa<math::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
      if (isa<complex::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
      if (isa<arith::NegFOp>(def))
        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
      if (isa<complex::NegOp>(def))
        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
      if (isa<arith::TruncFOp>(def))
        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
      if (isa<arith::ExtFOp>(def))
        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
      if (isa<arith::FPToSIOp>(def))
        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
      if (isa<arith::FPToUIOp>(def))
        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
      if (isa<arith::SIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
      if (isa<arith::UIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
      if (isa<arith::ExtSIOp>(def))
        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
      if (isa<arith::ExtUIOp>(def))
        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
      if (isa<arith::IndexCastOp>(def))
        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
      if (isa<arith::TruncIOp>(def))
        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
      if (isa<complex::ImOp>(def))
        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
      if (isa<complex::ReOp>(def))
        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
      if (isa<arith::BitcastOp>(def))
        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
    bool hasSpDep = xDepSp || yDepSp;
    if (x.has_value() && y.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
      if (isa<complex::MulOp>(def))
        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
      if (isa<arith::MulIOp>(def))
        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
      if (isa<arith::AddFOp>(def))
        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
      if (isa<complex::AddOp>(def))
        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
      if (isa<arith::AddIOp>(def))
        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
      if (isa<arith::SubFOp>(def))
        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
      if (isa<complex::SubOp>(def))
        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
      if (isa<arith::SubIOp>(def))
        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
      if (isa<arith::AndIOp>(def))
        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
      if (isa<arith::OrIOp>(def))
        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
      if (isa<arith::XOrIOp>(def))
        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
      if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
        if (ci.getPredicate() == arith::CmpIPredicate::eq ||
            ci.getPredicate() == arith::CmpIPredicate::sle ||
            ci.getPredicate() == arith::CmpIPredicate::sge ||
            ci.getPredicate() == arith::CmpIPredicate::ule ||
            ci.getPredicate() == arith::CmpIPredicate::uge) {
          // We cannot sparsify comparisons whose predicate holds on equal
          // operands: 0 <= 0 yields true, which densifies the result.
          return {std::nullopt, false};
        }

        auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                        ci.getPredicateAttr());
        return {e, hasSpDep};
      }
      if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
        if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
            cf.getPredicate() == arith::CmpFPredicate::OGE ||
            cf.getPredicate() == arith::CmpFPredicate::OLE ||
            cf.getPredicate() == arith::CmpFPredicate::ONE ||
            cf.getPredicate() == arith::CmpFPredicate::UEQ ||
            cf.getPredicate() == arith::CmpFPredicate::UGE ||
            cf.getPredicate() == arith::CmpFPredicate::ULE ||
            cf.getPredicate() == arith::CmpFPredicate::ORD ||
            cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify comparisons whose predicate holds on equal
          // operands: 0 <= 0 yields true, which densifies the result.
          return {std::nullopt, false};
        }
        auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                        cf.getPredicateAttr());
        return {e, hasSpDep};
      }
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
      }
    }
  }
  // Construct ternary operations if subexpressions can be built.
  if (def->getNumOperands() == 3) {
    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
    const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
    bool hasSpDep = xDepSp || yDepSp || zDepSp;
    if (x.has_value() && y.has_value() && z.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
      }
    }
  }

  // If we reach here, we are dealing with an operation that is not currently
  // sparsifiable. We can still generate code for it if all its operands only
  // have dense dependencies (i.e., all the values are loaded from dense
  // tensors).
  if (def->getNumResults() != 1) // only handle single-result operations
    return {std::nullopt, false};

  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
  // Build all the subexpressions.
  for (Value operand : def->getOperands())
    subExp.push_back(buildTensorExp(op, operand));

  if (llvm::all_of(subExp,
                   [](auto e) { return e.first.has_value() && !e.second; })) {
    // All the subexpressions can be built and have *no* sparse dependencies.
    if (subExp.size() == 2) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      *subExp[1].first, def);
      return {e, false};
    }
    if (subExp.size() == 1) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      detail::kInvalidId, def);
      return {e, false};
    }
  }
  // Cannot build.
  return {std::nullopt, false};
}

static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of the overlap region.
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge the cloned block and return the yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
    return rewriter.create<math::AbsIOp>(loc, v0);
  case TensorExp::Kind::kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case TensorExp::Kind::kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case TensorExp::Kind::kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case TensorExp::Kind::kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case TensorExp::Kind::kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case TensorExp::Kind::kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case TensorExp::Kind::kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case TensorExp::Kind::kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case TensorExp::Kind::kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case TensorExp::Kind::kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case TensorExp::Kind::kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case TensorExp::Kind::kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case TensorExp::Kind::kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case TensorExp::Kind::kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case TensorExp::Kind::kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case TensorExp::Kind::kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case TensorExp::Kind::kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case TensorExp::Kind::kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case TensorExp::Kind::kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case TensorExp::Kind::kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case TensorExp::Kind::kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case TensorExp::Kind::kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case TensorExp::Kind::kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case TensorExp::Kind::kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case TensorExp::Kind::kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case TensorExp::Kind::kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case TensorExp::Kind::kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case TensorExp::Kind::kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case TensorExp::Kind::kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case TensorExp::Kind::kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case TensorExp::Kind::kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case TensorExp::Kind::kCmpI: {
    auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kCmpF: {
    auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
  case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
  case TensorExp::Kind::kSelect:
    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
                         {v0});
  case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
  case TensorExp::Kind::kReduce: {
    ReduceOp redOp = cast<ReduceOp>(expr.op);
    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
  }
  case TensorExp::Kind::kDenseOp: {
    Operation *actualOp = expr.op;
    IRMapping mapping;
    mapping.map(actualOp->getOperand(0), v0);
    if (actualOp->getNumOperands() == 2)
      mapping.map(actualOp->getOperand(1), v1);
    return rewriter.clone(*actualOp, mapping)->getResult(0);
  }
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir