MLIR  19.0.0git
Merger.cpp
Go to the documentation of this file.
1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace sparse_tensor {
21 
/// Number of child expressions a tensor-expression kind carries, as
/// classified by `getExpArity`.
enum class ExpArity {
  /// Leaf kinds: no child expressions.
  kNullary,
  /// Kinds with a single child expression.
  kUnary,
  /// Kinds with (at most) two child expressions.
  kBinary,
};
27 
29  switch (k) {
30  // Leaf.
35  return ExpArity::kNullary;
70  return ExpArity::kUnary;
71  // Binary operations.
95  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
96  return ExpArity::kBinary;
97  }
98  llvm_unreachable("unexpected kind");
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Constructors.
103 //===----------------------------------------------------------------------===//
104 
106  Operation *o, Attribute a)
107  : kind(k), val(v), op(o) {
108  switch (kind) {
109  // Leaf.
111  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
112  tensor = x;
113  return;
115  assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
116  return;
118  assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
119  return;
121  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
122  loop = x;
123  return;
124  // Unary operations.
145  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
146  children.e0 = x;
147  children.e1 = y;
148  return;
160  assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
161  children.e0 = x;
162  children.e1 = y;
163  return;
166  assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
167  children.e0 = x;
168  children.e1 = y;
169  return;
171  // No assertion on y can be made, as the branching paths involve both
172  // a unary (`mapSet`) and binary (`disjSet`) pathway.
173  assert(x != detail::kInvalidId && !v && o);
174  children.e0 = x;
175  children.e1 = y;
176  return;
177  // Binary operations.
197  assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
198  children.e0 = x;
199  children.e1 = y;
200  return;
203  assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
204  attr = a;
205  children.e0 = x;
206  children.e1 = y;
207  return;
210  assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
211  children.e0 = x;
212  children.e1 = y;
213  return;
215  assert(x != detail::kInvalidId && !v && o);
216  children.e0 = x;
217  children.e1 = y;
218  return;
219  }
220  llvm_unreachable("unexpected kind");
221 }
222 
223 Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
224  unsigned maxLvlRank)
225  : outTensor(numInputOutputTensors - 1),
226  syntheticTensor(numInputOutputTensors),
227  numTensors(numInputOutputTensors + 1), numLoops(numLoops),
228  hasSparseOut(false),
229  lvlTypes(numTensors,
230  std::vector<LevelType>(numLoops, LevelFormat::Undef)),
231  loopToLvl(numTensors,
232  std::vector<std::optional<Level>>(numLoops, std::nullopt)),
233  lvlToLoop(numTensors,
234  std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
235  loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
236  numTensors, std::nullopt)),
237  levelToDependentLoop(numTensors,
238  std::vector<std::vector<LoopCoeffPair>>(
239  maxLvlRank, std::vector<LoopCoeffPair>())),
240  loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
241 
242 //===----------------------------------------------------------------------===//
243 // Lattice methods.
244 //===----------------------------------------------------------------------===//
245 
247  assert(isValidTensorId(t));
248  const ExprId eNew(tensorExps.size());
249  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
250  Value(), nullptr, nullptr);
251  return eNew;
252 }
253 
255  assert(isValidLoopId(i));
256  const ExprId eNew(tensorExps.size());
257  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
258  Value(), nullptr, nullptr);
259  return eNew;
260 }
261 
263  const ExprId eNew(tensorExps.size());
264  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
265  detail::kInvalidId, v, nullptr, nullptr);
266  return eNew;
267 }
268 
270  const ExprId eNew(tensorExps.size());
271  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
272  detail::kInvalidId, Value(), nullptr, nullptr);
273  return eNew;
274 }
275 
277  Attribute attr) {
278  assert(k > TensorExp::Kind::kLoopVar);
279  const ExprId eNew(tensorExps.size());
280  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
281  return eNew;
282 }
283 
285  Attribute attr) {
286  assert(k > TensorExp::Kind::kLoopVar);
287  const ExprId eNew(tensorExps.size());
288  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
289  return eNew;
290 }
291 
293  const LatPointId pNew(latPoints.size());
294  const unsigned size = numLoops * numTensors;
295  const TensorLoopId b = makeTensorLoopId(t, i);
296  latPoints.emplace_back(size, e);
297  latPoints[pNew].bits.set(b);
298  return pNew;
299 }
300 
301 LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
302  assert(bits.size() == numLoops * numTensors);
303  const LatPointId pNew(latPoints.size());
304  latPoints.emplace_back(bits, e);
305  return pNew;
306 }
307 
309  const LatSetId sNew(latSets.size());
310  latSets.emplace_back();
311  return sNew;
312 }
313 
315  Operation *op) {
316  TensorExp::Kind kind = exp(e).kind;
317  Attribute attr = exp(e).attr;
318  const LatPointId pNew(latPoints.size());
319  const auto &point0 = lat(p0);
320  const auto &point1 = lat(p1);
321  BitVector bits(point0.bits);
322  bits |= point1.bits;
323  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
324  latPoints.emplace_back(bits, ne);
325  return pNew;
326 }
327 
329  const LatSetId sNew = addSet();
330  auto &setNew = latSets[sNew];
331  for (const LatPointId p0 : set(s0))
332  for (const LatPointId p1 : set(s1))
333  setNew.push_back(conjLat(e, p0, p1, op));
334  return sNew;
335 }
336 
338  const LatSetId sNew = conjSet(e, s0, s1, op);
339  TensorExp::Kind kind = exp(e).kind;
340 
341  // Followed by all in s0.
342  latSets[sNew].append(latSets[s0]);
343  // Map binary 0-y to unary -y.
344  // TODO: move this if-else logic into buildLattices
345  if (kind == TensorExp::Kind::kSubF)
346  s1 = mapSet(TensorExp::Kind::kNegF, s1);
347  else if (kind == TensorExp::Kind::kSubC)
348  s1 = mapSet(TensorExp::Kind::kNegC, s1);
349  else if (kind == TensorExp::Kind::kSubI)
350  s1 = mapSet(TensorExp::Kind::kNegI, s1);
351  // Followed by all in s1.
352  latSets[sNew].append(latSets[s1]);
353  return sNew;
354 }
355 
357  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
358  exp(e).kind == TensorExp::Kind::kCmpF);
359  const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360 
361  ExprId e0 = exp(e).children.e0;
362  ExprId e1 = exp(e).children.e1;
363  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
364  exp(e1).kind == TensorExp::Kind::kSynZero) {
365  // lhs and rhs can't be synthetic zero at the same time.
366  assert(exp(e0).kind != exp(e1).kind);
367  // If one of the operands has already been assigned to zero (the
368  // element is absent in the corresponding operand), then we do not
369  // need to build disjunctive set for it.
370  return sNew;
371  }
372 
373  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
374  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
375  latSets[sNew].append(latSets[lhsSet]);
376  latSets[sNew].append(latSets[rhsSet]);
377  return sNew;
378 }
379 
381  bool includeLeft, TensorExp::Kind ltrans,
382  Operation *opleft, bool includeRight,
383  TensorExp::Kind rtrans, Operation *opright) {
384  const LatSetId sNew = conjSet(e, s0, s1, orig);
385  // Left Region.
386  if (includeLeft) {
387  if (opleft)
388  s0 = mapSet(ltrans, s0, Value(), opleft);
389  latSets[sNew].append(latSets[s0]);
390  }
391  // Right Region.
392  if (includeRight) {
393  if (opright)
394  s1 = mapSet(rtrans, s1, Value(), opright);
395  latSets[sNew].append(latSets[s1]);
396  }
397  return sNew;
398 }
399 
401  Operation *op) {
402  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
403  TensorExp::Kind::kDenseOp == kind);
404  const LatSetId sNew = addSet();
405  auto &setNew = latSets[sNew];
406  for (const LatPointId p : set(s0)) {
407  const auto &point = latPoints[p];
408  setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
409  }
410  return sNew;
411 }
412 
414  TensorExp::Kind kind = exp(e).kind;
415  Attribute a = exp(e).attr;
416  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
417  // Must be a binary operation.
418  const LatSetId sNew = addSet();
419  auto &setNew = latSets[sNew];
420  const ExprId zeroExp = addSynZeroExp();
421  for (const LatPointId p : set(s0)) {
422  const auto &point = latPoints[p];
423  ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
424  : addExp(kind, point.exp, zeroExp, nullptr, a);
425  setNew.push_back(addLat(point.bits, newExp));
426  }
427  return sNew;
428 }
429 
431  const LatSetId sNew = addSet();
432  auto &setNew = latSets[sNew];
433  const auto &set0 = set(s0);
434  assert(!set0.empty());
435  const LatPointId p0 = set0[0];
436  for (const LatPointId p1 : set0) {
437  bool add = true;
438  if (p0 != p1) {
439  // Check whether this is a straightforward copy.
440  if (expIsTensor(latPoints[p1].exp, outTensor))
441  continue;
442  // Check whether this conjunction is already covered.
443  for (const LatPointId p2 : setNew) {
444  assert(!latGT(p1, p2)); // Lj => Li would be bad
445  if (onlyDenseDiff(p2, p1)) {
446  add = false;
447  break;
448  }
449  }
450  assert(!add || latGT(p0, p1));
451  }
452  if (add)
453  setNew.push_back(p1);
454  }
455  for (const LatPointId p : setNew)
456  latPoints[p].simple = simplifyCond(sNew, p);
457  return sNew;
458 }
459 
461  // First determine if this lattice point is a *singleton*, i.e.,
462  // the last point in a lattice, no other is less than this one.
463  bool isSingleton = true;
464  for (const LatPointId p1 : set(s0)) {
465  if (p0 != p1 && latGT(p0, p1)) {
466  isSingleton = false;
467  break;
468  }
469  }
470 
471  BitVector simple(latPoints[p0].bits);
472  bool reset = isSingleton && hasAnySparse(simple);
473  const TensorLoopId be = simple.size();
474  TensorLoopId offset = 0; // relative to the end
475  if (!reset)
476  // Starts resetting from a dense level, so that the first bit (if kept)
477  // is not undefined level-type.
478  for (unsigned b = 0; b < be; b++) {
479  if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
480  offset = be - b - 1; // relative to the end
481  break;
482  }
483  }
484 
485  // Now apply the two basic rules. We also iterate the bits reversely to always
486  // keep the rightmost bit (which could possibly be a synthetic tensor).
487  for (unsigned b = be - 1 - offset, i = 0; i < be;
488  b = b == 0 ? be - 1 : b - 1, i++) {
489  // Slice on dense level has `locate` property as well, and can be optimized.
490  if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
491  const auto lt = getLvlType(b);
492  if (!lt.hasSparseSemantic()) {
493  if (reset)
494  simple.reset(b);
495  reset = true;
496  }
497  }
498  }
499  return simple;
500 }
501 
503  const BitVector &bitsi = lat(i).bits;
504  const BitVector &bitsj = lat(j).bits;
505  assert(bitsi.size() == bitsj.size());
506  if (bitsi.count() > bitsj.count()) {
507  for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
508  if (bitsj[b] && !bitsi[b])
509  return false;
510  return true;
511  }
512  return false;
513 }
514 
516  BitVector tmp(latPoints[j].bits);
517  tmp ^= latPoints[i].bits;
518  return !hasAnySparse(tmp);
519 }
520 
522  const auto &expr = exp(e);
523  // First we check `expIsTensor`.
524  if (expr.kind == TensorExp::Kind::kTensor)
525  return expr.tensor == t;
526 
527  switch (getExpArity(expr.kind)) {
528  case ExpArity::kNullary:
529  return false;
530  case ExpArity::kUnary: {
531  const ExprId e0 = expr.children.e0;
532  return expContainsTensor(e0, t);
533  }
534  case ExpArity::kBinary: {
535  const ExprId e0 = expr.children.e0;
536  const ExprId e1 = expr.children.e1;
537  return expContainsTensor(e0, t) || expContainsTensor(e1, t);
538  }
539  }
540  llvm_unreachable("unexpected arity");
541 }
542 
544  const auto &expr = exp(e);
545  switch (expr.kind) {
549  return expContainsTensor(expr.children.e0, outTensor);
553  return expContainsTensor(expr.children.e1, outTensor) ||
554  hasNegateOnOut(expr.children.e0);
556  bool lhsNeg = hasNegateOnOut(expr.children.e0);
557  if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
558  return hasNegateOnOut(expr.children.e1);
559  return lhsNeg;
560  }
561  default: {
562  switch (getExpArity(expr.kind)) {
563  case ExpArity::kNullary:
564  return false;
565  case ExpArity::kUnary:
566  return hasNegateOnOut(expr.children.e0);
567  case ExpArity::kBinary:
568  return hasNegateOnOut(expr.children.e0) ||
569  hasNegateOnOut(expr.children.e1);
570  }
571  }
572  }
573  llvm_unreachable("unexpected kind");
574 }
575 
577  assert(isValidTensorId(t));
578  const auto &expr = exp(e);
579  switch (expr.kind) {
580  // Leaf.
582  return expr.tensor == t;
586  return false;
587  // Unary operations.
620  return isSingleCondition(t, expr.children.e0);
623  return false;
624  // Binary operations.
625  case TensorExp::Kind::kDivF: // note: x / c only
629  assert(!maybeZero(expr.children.e1));
630  return isSingleCondition(t, expr.children.e0);
631  case TensorExp::Kind::kShrS: // note: x >> inv only
634  assert(isInvariant(expr.children.e1));
635  return isSingleCondition(t, expr.children.e0);
641  if (isSingleCondition(t, expr.children.e0))
642  return isSingleCondition(t, expr.children.e1) ||
643  isInvariant(expr.children.e1);
644  if (isSingleCondition(t, expr.children.e1))
645  return isInvariant(expr.children.e0);
646  return false;
650  return isSingleCondition(t, expr.children.e0) &&
651  isSingleCondition(t, expr.children.e1);
660  return false;
662  // Since Merger guarantees all the operands of the kDenseOp to be dense, the
663  // operation must be single-condition.
664  return true;
665  }
666  llvm_unreachable("unexpected kind");
667 }
668 
669 bool Merger::hasAnySparse(const BitVector &bits) const {
670  for (TensorLoopId b : bits.set_bits()) {
671  const auto lt = getLvlType(b);
672  if (lt.hasSparseSemantic())
673  return true;
674  }
675  return hasSparseIdxReduction(bits);
676 }
677 
678 bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
679  for (TensorLoopId b : bits.set_bits())
681  return true;
682  return false;
683 }
684 
685 #ifndef NDEBUG
686 
687 //===----------------------------------------------------------------------===//
688 // Print methods (for debugging).
689 //===----------------------------------------------------------------------===//
690 
691 static const char *kindToOpSymbol(TensorExp::Kind kind) {
692  switch (kind) {
693  // Leaf.
695  return "tensor";
697  return "invariant";
699  return "index";
701  return "0";
702  // Unary operations.
706  return "abs";
708  return "ceil";
710  return "floor";
713  return "sqrt";
716  return "expm1";
719  return "log1p";
722  return "sin";
725  return "tanh";
729  return "-";
741  return "complex.im";
743  return "complex.re";
745  return "cast";
747  return "binary_branch";
749  return "unary";
751  return "select";
752  // Binary operations.
756  return "*";
761  return "/";
765  return "+";
769  return "-";
771  return "&";
773  return "|";
775  return "^";
777  return "a>>";
779  return ">>";
781  return "<<";
784  return "cmp";
786  return "binary";
788  return "reduce";
790  return "dense";
791  }
792  llvm_unreachable("unexpected kind for symbol");
793 }
794 
795 void Merger::dumpExp(ExprId e) const {
796  const auto &expr = exp(e);
797  switch (expr.kind) {
798  // Leaf.
800  if (expr.tensor == syntheticTensor)
801  llvm::dbgs() << "synthetic_";
802  else if (expr.tensor == outTensor)
803  llvm::dbgs() << "output_";
804  llvm::dbgs() << "tensor_" << expr.tensor;
805  break;
807  llvm::dbgs() << "invariant";
808  break;
810  llvm::dbgs() << "0";
811  break;
813  llvm::dbgs() << "loopvar_" << expr.loop;
814  break;
815  // Unary operations.
850  llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
851  dumpExp(expr.children.e0);
852  break;
853  // Binary operations.
878  llvm::dbgs() << "(";
879  dumpExp(expr.children.e0);
880  llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
881  if (expr.attr)
882  llvm::dbgs() << "{" << expr.attr << "}";
883  if (expr.children.e1 != detail::kInvalidId) {
884  llvm::dbgs() << " ";
885  dumpExp(expr.children.e1);
886  llvm::dbgs() << ")";
887  } else {
888  assert(expr.kind == TensorExp::Kind::kDenseOp);
889  }
890  break;
891  }
892 }
893 
895  const auto &point = lat(p);
896  llvm::dbgs() << "lat(";
897  dumpBits(point.bits);
898  llvm::dbgs() << " :";
899  dumpBits(point.simple);
900  llvm::dbgs() << " : ";
901  dumpExp(point.exp);
902  llvm::dbgs() << " )\n";
903 }
904 
905 void Merger::dumpSet(LatSetId s) const {
906  const auto &ss = set(s);
907  llvm::dbgs() << "{ #" << ss.size() << "\n";
908  for (const LatPointId p : ss) {
909  llvm::dbgs() << " ";
910  dumpLat(p);
911  }
912  llvm::dbgs() << "}\n";
913 }
914 
915 void Merger::dumpBits(const BitVector &bits) const {
916  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
917  if (bits[b]) {
918  const TensorId t = tensor(b);
919  const LoopId i = loop(b);
920  const auto lt = lvlTypes[t][i];
922  llvm::dbgs() << " DEP_" << t << "_" << i;
923  else
924  llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
925  }
926  }
927 }
928 
929 #endif // NDEBUG
930 
931 //===----------------------------------------------------------------------===//
932 // Builder methods.
933 //===----------------------------------------------------------------------===//
934 
936  // NOTE: The `expr` reference will be invalidated by recursive calls
937  // (and any other method that may add new expressions); therefore, the
938  // code below must make sure to copy fields of `expr` into local variables
939  // before making any recursive calls.
940  const auto &expr = exp(e);
941  const TensorExp::Kind kind = expr.kind;
942  switch (kind) {
943  // Leaf.
948  // Either the loop-var is really used in the tensor expression, or it is
949  // set to the undefined loop-var in that level. An invariant expression,
950  // a proper index value, and a truly dynamic sparse output tensor are set
951  // to a synthetic tensor with undefined indices only to ensure the
952  // iteration space is not skipped as a result of their contents.
953  const LatSetId s = addSet();
954  TensorId t = syntheticTensor;
955  if (kind == TensorExp::Kind::kTensor) {
956  t = expr.tensor;
957  if (hasSparseOut && t == outTensor)
958  t = syntheticTensor;
959  }
960  latSets[s].push_back(addLat(t, i, e));
961  return s;
962  }
963  // Unary operations.
995  // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
996  // lattice set of the operand through the operator into a new set.
997  //
998  // -y|!y | y |
999  // --+---+---+
1000  // | 0 |-y |
1001  {
1002  const ExprId e0 = expr.children.e0;
1003  const Value v = expr.val;
1004  return mapSet(kind, buildLattices(e0, i), v);
1005  }
1008  // The left or right half of a binary operation which has already
1009  // been split into separate operations for each region.
1010  {
1011  const ExprId e0 = expr.children.e0;
1012  Operation *const op = expr.op;
1013  return mapSet(kind, buildLattices(e0, i), Value(), op);
1014  }
1016  // A custom unary operation.
1017  //
1018  // op y| !y | y |
1019  // ----+----------+------------+
1020  // | absent() | present(y) |
1021  {
1022  const ExprId e0 = expr.children.e0;
1023  UnaryOp unop = cast<UnaryOp>(expr.op);
1024  const LatSetId child0 = buildLattices(e0, i);
1025  Region &absentRegion = unop.getAbsentRegion();
1026  if (absentRegion.empty()) {
1027  // Simple mapping over existing values.
1028  return mapSet(kind, child0, Value(), unop);
1029  }
1030  // Use a disjunction with `unop` on the left and the absent value as an
1031  // invariant on the right.
1032  Block &absentBlock = absentRegion.front();
1033  YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1034  const Value absentVal = absentYield.getSingleResult();
1035  const ExprId rhs = addInvariantExp(absentVal);
1036  return disjSet(e, child0, buildLattices(rhs, i), unop);
1037  }
1038  // Binary operations.
1043  // A multiplicative operation only needs to be performed
1044  // for the conjunction of sparse iteration spaces.
1045  //
1046  // x*y|!y | y |
1047  // ---+---+---+
1048  // !x | 0 | 0 |
1049  // x | 0 |x*y|
1050  //
1051  // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1052  {
1053  const ExprId e0 = expr.children.e0;
1054  const ExprId e1 = expr.children.e1;
1055  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1056  }
1061  // A division is tricky, since 0/0, 0/c, c/0 all have
1062  // specific outcomes for floating-point and integers.
1063  // Thus, we need to traverse the full iteration space.
1064  //
1065  // x/y|!y | y |
1066  // ---+---+---+
1067  // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1068  // x |x/0|x/y| INT: x/0=exception for any x
1069  //
1070  // TODO: for now we "fixed" this by only accepting x/c cases
1071  // during expression building, so that the conjunction
1072  // rules applies (viz. x/c = x*(1/c) as far as lattice
1073  // construction is concerned).
1074  {
1075  const ExprId e0 = expr.children.e0;
1076  const ExprId e1 = expr.children.e1;
1077  assert(!maybeZero(e1));
1078  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1079  }
1086  case TensorExp::Kind::kOrI:
1088  // An additive operation needs to be performed
1089  // for the disjunction of sparse iteration spaces.
1090  //
1091  // x+y|!y | y | x-y|!y | y |
1092  // ---+---+---+ ---+---+---+
1093  // !x | 0 | y | !x | 0 |-y |
1094  // x | x |x+y| x | x |x-y|
1095  {
1096  const ExprId e0 = expr.children.e0;
1097  const ExprId e1 = expr.children.e1;
1098  return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1099  }
1102  // A comparison operation needs to be performed
1103  // for the disjunction of sparse iteration spaces.
1104  //
1105  // x < y | !y | y |
1106  // -------+-------+-------+
1107  // !x | 0 | 0 < y |
1108  // x | x < 0 | x < y |
1109  {
1110  const ExprId e0 = expr.children.e0;
1111  const ExprId e1 = expr.children.e1;
1112  return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1113  }
1117  // A shift operation by an invariant amount (viz. tensor expressions
1118  // can only occur at the left-hand-side of the operator) can be handled
1119  // with the conjunction rule.
1120  {
1121  const ExprId e0 = expr.children.e0;
1122  const ExprId e1 = expr.children.e1;
1123  assert(isInvariant(e1));
1124  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1125  }
1127  // A custom binary operation.
1128  //
1129  // x op y| !y | y |
1130  // ------+---------+--------------+
1131  // !x | empty | right(y) |
1132  // x | left(x) | overlap(x,y) |
1133  {
1134  const ExprId e0 = expr.children.e0;
1135  const ExprId e1 = expr.children.e1;
1136  BinaryOp binop = cast<BinaryOp>(expr.op);
1137  const LatSetId child0 = buildLattices(e0, i);
1138  const LatSetId child1 = buildLattices(e1, i);
1139  Region &leftRegion = binop.getLeftRegion();
1140  Region &rightRegion = binop.getRightRegion();
1141  // Left Region.
1142  Operation *leftYield = nullptr;
1143  if (!leftRegion.empty()) {
1144  Block &leftBlock = leftRegion.front();
1145  leftYield = leftBlock.getTerminator();
1146  }
1147  // Right Region.
1148  Operation *rightYield = nullptr;
1149  if (!rightRegion.empty()) {
1150  Block &rightBlock = rightRegion.front();
1151  rightYield = rightBlock.getTerminator();
1152  }
1153  bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1154  bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1155  return combiSet(e, child0, child1, binop, includeLeft,
1156  TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
1157  TensorExp::Kind::kBinaryBranch, rightYield);
1158  }
1160  // A custom reduce operation.
1161  {
1162  const ExprId e0 = expr.children.e0;
1163  const ExprId e1 = expr.children.e1;
1164  Operation *const op = expr.op;
1165  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1166  }
1168  // It does not really matter whether we use conjunctive/disjunctive set
1169  // here, as all the operands of kDenseOp must be dense, the disjunctive set
1170  // will be optimized into conjunctive set eventually.
1171  if (expr.children.e1 == detail::kInvalidId) {
1172  const ExprId e0 = expr.children.e0;
1173  Operation *const op = expr.op;
1174  return mapSet(kind, buildLattices(e0, i), Value(), op);
1175  }
1176 
1177  const ExprId e0 = expr.children.e0;
1178  const ExprId e1 = expr.children.e1;
1179  Operation *const op = expr.op;
1180  return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1181  }
1182  }
1183  llvm_unreachable("unexpected expression kind");
1184 }
1185 
1186 std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1187  // Build the linalg semantics backward from yield.
1188  Operation *yield = op.getRegion().front().getTerminator();
1189  assert(isa<linalg::YieldOp>(yield));
1190  return buildTensorExp(op, yield->getOperand(0)).first;
1191 }
1192 
1193 /// Only returns false if we are certain this is a nonzero.
1194 bool Merger::maybeZero(ExprId e) const {
1195  const auto &expr = exp(e);
1196  if (expr.kind == TensorExp::Kind::kInvariant) {
1197  if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1198  ArrayAttr arrayAttr = c.getValue();
1199  return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1200  cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1201  }
1202  if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1203  return c.value() == 0;
1204  if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1205  return c.value().isZero();
1206  }
1207  return true;
1208 }
1209 
1210 Type Merger::inferType(ExprId e, Value src) const {
1211  // Obtain the destination type from the cast node.
1212  Type dtp = exp(e).val.getType();
1213  // Inspect source type. For vector types, apply the same
1214  // vectorization to the destination type.
1215  if (auto vtp = dyn_cast<VectorType>(src.getType()))
1216  return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1217  return dtp;
1218 }
1219 
1220 /// Ensures that the sparsifier can generate code for expression.
1221 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1222  // Arguments are always admissible.
1223  if (isa<BlockArgument>(v))
1224  return true;
1225  // Accept index anywhere.
1226  Operation *def = v.getDefiningOp();
1227  if (isa<linalg::IndexOp>(def))
1228  return true;
1229  // Operation defined outside branch.
1230  if (def->getBlock() != block)
1231  return def->getBlock() != op->getBlock(); // invariant?
1232  // Operation defined within branch. Anything is accepted,
1233  // as long as all subexpressions are admissible.
1234  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1235  if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1236  return false;
1237  return true;
1238 }
1239 
1240 /// Ensures that the sparsifier can generate code for branch.
1241 static bool isAdmissibleBranch(Operation *op, Region &region) {
1242  if (region.empty())
1243  return true;
1244  // Build the semi-ring branch semantics backward from yield.
1245  Operation *yield = region.front().getTerminator();
1246  assert(isa<YieldOp>(yield));
1247  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1248 }
1249 
1250 std::pair<std::optional<ExprId>, bool>
1251 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1252  // Recursion leaves.
1253  if (auto arg = dyn_cast<BlockArgument>(v)) {
1254  const TensorId tid = makeTensorId(arg.getArgNumber());
1255  // Any argument of the generic op that is not marked as a scalar
1256  // argument is considered a tensor, indexed by the implicit loop
1257  // bounds. This includes rank-0 tensor arguments.
1258  if (arg.getOwner()->getParentOp() == op) {
1259  OpOperand &t = op->getOpOperand(tid);
1260  bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1261  if (!op.isScalar(&t))
1262  return {addTensorExp(tid), hasSpDep};
1263  v = t.get(); // get scalar value
1264  }
1265  // Any other argument (marked as scalar argument for the generic op
1266  // or belonging to an enveloping op) is considered invariant.
1267  return {addInvariantExp(v), /*hasSpDep=*/false};
1268  }
1269  // Something defined outside is invariant.
1270  Operation *def = v.getDefiningOp();
1271  if (def->getBlock() != &op.getRegion().front())
1272  return {addInvariantExp(v), /*hasSpDep=*/false};
1273  // Construct index operations.
1274  if (def->getNumOperands() == 0) {
1275  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1276  return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1277  }
1278 
1279  // Construct unary operations if subexpression can be built.
1280  if (def->getNumOperands() == 1) {
1281  const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1282  if (x.has_value()) {
1283  const ExprId e = *x;
1284  if (isa<math::AbsFOp>(def))
1285  return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1286  if (isa<complex::AbsOp>(def))
1287  return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
1288  if (isa<math::AbsIOp>(def))
1289  return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1290  if (isa<math::CeilOp>(def))
1291  return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1292  if (isa<math::FloorOp>(def))
1293  return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1294  if (isa<math::SqrtOp>(def))
1295  return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1296  if (isa<complex::SqrtOp>(def))
1297  return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1298  if (isa<math::ExpM1Op>(def))
1299  return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1300  if (isa<complex::Expm1Op>(def))
1301  return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1302  if (isa<math::Log1pOp>(def))
1303  return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1304  if (isa<complex::Log1pOp>(def))
1305  return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1306  if (isa<math::SinOp>(def))
1307  return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1308  if (isa<complex::SinOp>(def))
1309  return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1310  if (isa<math::TanhOp>(def))
1311  return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1312  if (isa<complex::TanhOp>(def))
1313  return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1314  if (isa<arith::NegFOp>(def))
1315  return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1316  if (isa<complex::NegOp>(def))
1317  return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1318  if (isa<arith::TruncFOp>(def))
1319  return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1320  if (isa<arith::ExtFOp>(def))
1321  return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1322  if (isa<arith::FPToSIOp>(def))
1323  return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1324  if (isa<arith::FPToUIOp>(def))
1325  return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1326  if (isa<arith::SIToFPOp>(def))
1327  return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1328  if (isa<arith::UIToFPOp>(def))
1329  return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1330  if (isa<arith::ExtSIOp>(def))
1331  return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1332  if (isa<arith::ExtUIOp>(def))
1333  return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
1334  if (isa<arith::IndexCastOp>(def))
1335  return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1336  if (isa<arith::TruncIOp>(def))
1337  return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
1338  if (isa<complex::ImOp>(def))
1339  return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
1340  if (isa<complex::ReOp>(def))
1341  return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1342  if (isa<arith::BitcastOp>(def))
1343  return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
1344  if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1345  if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1346  isAdmissibleBranch(unop, unop.getAbsentRegion()))
1347  return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1348  }
1349  if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1350  if (isAdmissibleBranch(selop, selop.getRegion()))
1351  return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
1352  }
1353  }
1354  }
1355  // Construct binary operations if subexpressions can be built.
1356  // See buildLattices() for an explanation of rejecting certain
1357  // division and shift operations.
1358  if (def->getNumOperands() == 2) {
1359  const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1360  const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1361  bool hasSpDep = xDepSp || yDepSp;
1362  if (x.has_value() && y.has_value()) {
1363  const ExprId e0 = *x;
1364  const ExprId e1 = *y;
1365  if (isa<arith::MulFOp>(def))
1366  return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
1367  if (isa<complex::MulOp>(def))
1368  return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
1369  if (isa<arith::MulIOp>(def))
1370  return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
1371  if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1372  return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
1373  if (isa<complex::DivOp>(def) && !maybeZero(e1))
1374  return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
1375  if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1376  return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
1377  if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1378  return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
1379  if (isa<arith::AddFOp>(def))
1380  return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
1381  if (isa<complex::AddOp>(def))
1382  return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
1383  if (isa<arith::AddIOp>(def))
1384  return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
1385  if (isa<arith::SubFOp>(def))
1386  return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
1387  if (isa<complex::SubOp>(def))
1388  return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
1389  if (isa<arith::SubIOp>(def))
1390  return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
1391  if (isa<arith::AndIOp>(def))
1392  return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
1393  if (isa<arith::OrIOp>(def))
1394  return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
1395  if (isa<arith::XOrIOp>(def))
1396  return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
1397  if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1398  return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
1399  if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1400  return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
1401  if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1402  return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
1403  if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1404  if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1405  ci.getPredicate() == arith::CmpIPredicate::sle &&
1406  ci.getPredicate() == arith::CmpIPredicate::sge &&
1407  ci.getPredicate() == arith::CmpIPredicate::ule &&
1408  ci.getPredicate() == arith::CmpIPredicate::uge) {
1409  // We can not sparsify comparison with equal, this is because 0 <= 0
1410  // yields true, and thus densifies the result.
1411  return {std::nullopt, false};
1412  }
1413 
1414  auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1415  ci.getPredicateAttr());
1416  return {e, hasSpDep};
1417  }
1418  if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1419  if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1420  cf.getPredicate() == arith::CmpFPredicate::OGE &&
1421  cf.getPredicate() == arith::CmpFPredicate::OLE &&
1422  cf.getPredicate() == arith::CmpFPredicate::ONE &&
1423  cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1424  cf.getPredicate() == arith::CmpFPredicate::UGE &&
1425  cf.getPredicate() == arith::CmpFPredicate::ULE &&
1426  cf.getPredicate() == arith::CmpFPredicate::ORD &&
1427  cf.getPredicate() == arith::CmpFPredicate::UNO) {
1428  // We can not sparsify comparison with equal, this is because 0 <= 0
1429  // yields true, and thus densifies the result.
1430  return {std::nullopt, false};
1431  }
1432  auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1433  cf.getPredicateAttr());
1434  return {e, hasSpDep};
1435  }
1436  if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1437  if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1438  (binop.getLeftIdentity() ||
1439  isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1440  (binop.getRightIdentity() ||
1441  isAdmissibleBranch(binop, binop.getRightRegion())))
1442  return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
1443  }
1444  }
1445  }
1446  // Construct ternary operations if subexpressions can be built.
1447  if (def->getNumOperands() == 3) {
1448  const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1449  const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1450  const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1451  bool hasSpDep = xDepSp || yDepSp || zDepSp;
1452  if (x.has_value() && y.has_value() && z.has_value()) {
1453  const ExprId e0 = *x;
1454  const ExprId e1 = *y;
1455  if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1456  if (isAdmissibleBranch(redop, redop.getRegion()))
1457  return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1458  }
1459  }
1460  }
1461 
1462  // If we reach here, we are dealing with an operation that is not currently
1463  // sparsifiable. We can still generate code for it if all its operands only
1464  // have dense dependencies (i.e., all the values are loaded from dense
1465  // tensors).
1466  if (def->getNumResults() != 1) // only handle single result operation.
1467  return {std::nullopt, false};
1468 
1469  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1470  // Builds all the sub-expressions
1471  for (Value operand : def->getOperands())
1472  subExp.push_back(buildTensorExp(op, operand));
1473 
1474  if (llvm::all_of(subExp,
1475  [](auto e) { return e.first.has_value() && !e.second; })) {
1476  // All the subexpressions can be built and has *no* sparse dependencies.
1477  if (subExp.size() == 2) {
1478  auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1479  *subExp[1].first, def);
1480  return {e, false};
1481  }
1482  if (subExp.size() == 1) {
1483  auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1484  detail::kInvalidId, def);
1485  return {e, false};
1486  }
1487  }
1488  // Cannot build.
1489  return {std::nullopt, false};
1490 }
1491 
1492 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1493  ValueRange vals) {
1494  // Make a clone of overlap region.
1495  Region tmpRegion;
1496  IRMapping mapper;
1497  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1498  Block &clonedBlock = tmpRegion.front();
1499  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1500  // Merge cloned block and return yield value.
1501  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1502  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1503  Value val = clonedYield.getSingleResult();
1504  rewriter.eraseOp(clonedYield);
1505  rewriter.eraseOp(placeholder);
1506  return val;
1507 }
1508 
1510  Operation *op, Value v0) {
1511  if (!v0)
1512  // Empty input value must be propagated.
1513  return Value();
1514  UnaryOp unop = cast<UnaryOp>(op);
1515  Region &presentRegion = unop.getPresentRegion();
1516  if (presentRegion.empty())
1517  // Uninitialized Value() will be interpreted as missing data in the
1518  // output.
1519  return Value();
1520  return insertYieldOp(rewriter, loc, presentRegion, {v0});
1521 }
1522 
1524  Operation *op, Value v0, Value v1) {
1525  if (!v0 || !v1)
1526  // Empty input values must be propagated.
1527  return Value();
1528  BinaryOp binop = cast<BinaryOp>(op);
1529  Region &overlapRegion = binop.getOverlapRegion();
1530  if (overlapRegion.empty())
1531  // Uninitialized Value() will be interpreted as missing data in the
1532  // output.
1533  return Value();
1534  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1535 }
1536 
1538  Value v1) const {
1539  const auto &expr = exp(e);
1540  switch (expr.kind) {
1541  // Leaf.
1546  llvm_unreachable("unexpected non-op");
1547  // Unary operations.
1549  return rewriter.create<math::AbsFOp>(loc, v0);
1550  case TensorExp::Kind::kAbsC: {
1551  auto type = cast<ComplexType>(v0.getType());
1552  auto eltType = cast<FloatType>(type.getElementType());
1553  return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1554  }
1556  return rewriter.create<math::AbsIOp>(loc, v0);
1558  return rewriter.create<math::CeilOp>(loc, v0);
1560  return rewriter.create<math::FloorOp>(loc, v0);
1562  return rewriter.create<math::SqrtOp>(loc, v0);
1564  return rewriter.create<complex::SqrtOp>(loc, v0);
1566  return rewriter.create<math::ExpM1Op>(loc, v0);
1568  return rewriter.create<complex::Expm1Op>(loc, v0);
1570  return rewriter.create<math::Log1pOp>(loc, v0);
1572  return rewriter.create<complex::Log1pOp>(loc, v0);
1574  return rewriter.create<math::SinOp>(loc, v0);
1576  return rewriter.create<complex::SinOp>(loc, v0);
1578  return rewriter.create<math::TanhOp>(loc, v0);
1580  return rewriter.create<complex::TanhOp>(loc, v0);
1582  return rewriter.create<arith::NegFOp>(loc, v0);
1584  return rewriter.create<complex::NegOp>(loc, v0);
1585  case TensorExp::Kind::kNegI: // no negi in std
1586  return rewriter.create<arith::SubIOp>(
1587  loc,
1588  rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1589  rewriter.getZeroAttr(v0.getType())),
1590  v0);
1592  return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1594  return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1596  return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1598  return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1600  return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1602  return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1604  return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1606  return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1608  return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1610  return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1611  case TensorExp::Kind::kCIm: {
1612  auto type = cast<ComplexType>(v0.getType());
1613  auto eltType = cast<FloatType>(type.getElementType());
1614  return rewriter.create<complex::ImOp>(loc, eltType, v0);
1615  }
1616  case TensorExp::Kind::kCRe: {
1617  auto type = cast<ComplexType>(v0.getType());
1618  auto eltType = cast<FloatType>(type.getElementType());
1619  return rewriter.create<complex::ReOp>(loc, eltType, v0);
1620  }
1622  return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1623  // Binary operations.
1625  return rewriter.create<arith::MulFOp>(loc, v0, v1);
1627  return rewriter.create<complex::MulOp>(loc, v0, v1);
1629  return rewriter.create<arith::MulIOp>(loc, v0, v1);
1631  return rewriter.create<arith::DivFOp>(loc, v0, v1);
1633  return rewriter.create<complex::DivOp>(loc, v0, v1);
1635  return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1637  return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1639  return rewriter.create<arith::AddFOp>(loc, v0, v1);
1641  return rewriter.create<complex::AddOp>(loc, v0, v1);
1643  return rewriter.create<arith::AddIOp>(loc, v0, v1);
1645  return rewriter.create<arith::SubFOp>(loc, v0, v1);
1647  return rewriter.create<complex::SubOp>(loc, v0, v1);
1649  return rewriter.create<arith::SubIOp>(loc, v0, v1);
1651  return rewriter.create<arith::AndIOp>(loc, v0, v1);
1652  case TensorExp::Kind::kOrI:
1653  return rewriter.create<arith::OrIOp>(loc, v0, v1);
1655  return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1657  return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1659  return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1661  return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1662  case TensorExp::Kind::kCmpI: {
1663  auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1664  return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1665  }
1666  case TensorExp::Kind::kCmpF: {
1667  auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1668  return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1669  }
1670  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1671  return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1672  {v0});
1674  return buildUnaryPresent(rewriter, loc, expr.op, v0);
1676  return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1677  {v0});
1679  return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1680  case TensorExp::Kind::kReduce: {
1681  ReduceOp redOp = cast<ReduceOp>(expr.op);
1682  return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1683  }
1685  Operation *actualOp = expr.op;
1686  IRMapping mapping;
1687  mapping.map(actualOp->getOperand(0), v0);
1688  if (actualOp->getNumOperands() == 2)
1689  mapping.map(actualOp->getOperand(1), v1);
1690  return rewriter.clone(*actualOp, mapping)->getResult(0);
1691  }
1692  }
1693  llvm_unreachable("unexpected expression kind in build");
1694 }
1695 
1696 } // namespace sparse_tensor
1697 } // namespace mlir
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:450
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Definition: Region.cpp:70
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:75
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Specialization of arith.constant op that returns an integer value.
Definition: Arith.h:53
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
Definition: Merger.cpp:314
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
Definition: Merger.cpp:337
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if the given tensor iterates only in the given tensor expression.
Definition: Merger.cpp:576
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
Definition: Merger.cpp:678
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
Definition: Merger.cpp:521
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operands set to zero,...
Definition: Merger.cpp:413
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level that contains a non-trivial index expression.
Definition: Merger.h:509
void dumpBits(const BitVector &bits) const
Definition: Merger.cpp:915
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
Definition: Merger.cpp:308
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
Definition: Merger.cpp:460
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
Definition: Merger.cpp:543
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
Definition: Merger.h:500
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets that also sets one of the operands to zero: (s0 /\_op s1 (e0 op e1...
Definition: Merger.cpp:356
void dumpSet(LatSetId s) const
Definition: Merger.cpp:905
void dumpLat(LatPointId p) const
Definition: Merger.cpp:894
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
Definition: Merger.cpp:380
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
Definition: Merger.cpp:246
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:935
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
Definition: Merger.cpp:328
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
Definition: Merger.cpp:276
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
Definition: Merger.cpp:269
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
Definition: Merger.h:248
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
Definition: Merger.cpp:1186
ArrayRef< LatPointId > set(LatSetId s) const
Definition: Merger.h:548
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:430
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
Definition: Merger.h:345
void dumpExp(ExprId e) const
Print methods (for debugging).
Definition: Merger.cpp:795
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
Definition: Merger.cpp:223
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to a sparse level-type.
Definition: Merger.cpp:669
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
Definition: Merger.cpp:292
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
Definition: Merger.cpp:254
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
Definition: Merger.cpp:502
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
Definition: Merger.h:540
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
Definition: Merger.h:347
const LatPoint & lat(LatPointId p) const
Definition: Merger.h:544
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
Definition: Merger.cpp:515
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
Definition: Merger.cpp:262
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
Definition: Merger.h:242
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Definition: Merger.h:398
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:1537
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
Definition: Merger.h:254
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
Definition: Merger.h:369
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
Definition: Merger.cpp:400
@ Type
An inlay hint that is for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
Definition: Merger.h:30
LevelFormat
This enum defines all supported storage formats without the level properties.
Definition: Enums.h:154
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v)
Ensures that the sparsifier can generate code for expression.
Definition: Merger.cpp:1221
unsigned LatSetId
LatSet identifiers.
Definition: Merger.h:57
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region, ValueRange vals)
Definition: Merger.cpp:1492
std::string toMLIRString(LevelType lt)
Definition: Enums.h:443
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
Definition: Merger.h:60
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
Definition: Merger.h:44
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0)
Definition: Merger.cpp:1509
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1)
Definition: Merger.cpp:1523
static bool isAdmissibleBranch(Operation *op, Region &region)
Ensures that the sparsifier can generate code for branch.
Definition: Merger.cpp:1241
unsigned LoopId
Loop identifiers.
Definition: Merger.h:38
static const char * kindToOpSymbol(TensorExp::Kind kind)
Definition: Merger.cpp:691
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
unsigned ExprId
TensorExp identifiers.
Definition: Merger.h:48
static ExpArity getExpArity(TensorExp::Kind k)
Definition: Merger.cpp:28
unsigned LatPointId
LatPoint identifiers.
Definition: Merger.h:52
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
Definition: Merger.h:64
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Definition: Merger.h:35
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
Definition: Merger.h:209
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
LoopId loop
kLoopVar expressions simply have a loop identifier.
Definition: Merger.h:96
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Definition: Merger.h:105
Kind
Tensor expression kind.
Definition: Merger.h:129
Children children
All other expressions hold the ExprIds of their children.
Definition: Merger.h:99
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
Definition: Merger.h:118
TensorId tensor
kTensor expressions simply have a tensor identifier.
Definition: Merger.h:93
Kind kind
Tensor expression kind.
Definition: Merger.h:89
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Definition: Merger.cpp:105
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.