Merger.cpp
1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 
18 namespace mlir {
19 namespace sparse_tensor {
20 
21 //===----------------------------------------------------------------------===//
22 // Constructors.
23 //===----------------------------------------------------------------------===//
24 
25 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
26  : kind(k), val(v), op(o) {
27  switch (kind) {
28  // Leaf.
29  case kTensor:
30  assert(x != -1u && y == -1u && !v && !o);
31  tensor = x;
32  break;
33  case kInvariant:
34  assert(x == -1u && y == -1u && v && !o);
35  break;
36  case kIndex:
37  assert(x != -1u && y == -1u && !v && !o);
38  index = x;
39  break;
40  // Unary operations.
41  case kAbsF:
42  case kAbsC:
43  case kAbsI:
44  case kCeilF:
45  case kFloorF:
46  case kSqrtF:
47  case kSqrtC:
48  case kExpm1F:
49  case kExpm1C:
50  case kLog1pF:
51  case kLog1pC:
52  case kSinF:
53  case kSinC:
54  case kTanhF:
55  case kTanhC:
56  case kNegF:
57  case kNegC:
58  case kNegI:
59  case kCIm:
60  case kCRe:
61  assert(x != -1u && y == -1u && !v && !o);
62  children.e0 = x;
63  children.e1 = y;
64  break;
65  case kTruncF:
66  case kExtF:
67  case kCastFS:
68  case kCastFU:
69  case kCastSF:
70  case kCastUF:
71  case kCastS:
72  case kCastU:
73  case kCastIdx:
74  case kTruncI:
75  case kBitCast:
76  assert(x != -1u && y == -1u && v && !o);
77  children.e0 = x;
78  children.e1 = y;
79  break;
80  case kBinaryBranch:
81  case kSelect:
82  assert(x != -1u && y == -1u && !v && o);
83  children.e0 = x;
84  children.e1 = y;
85  break;
86  case kUnary:
87  // No assertion on y can be made, as the branching paths involve both
88  // a unary (mapSet) and binary (takeDisj) pathway.
89  assert(x != -1u && !v && o);
90  children.e0 = x;
91  children.e1 = y;
92  break;
93  // Binary operations.
94  case kMulF:
95  case kMulC:
96  case kMulI:
97  case kDivF:
98  case kDivC:
99  case kDivS:
100  case kDivU:
101  case kAddF:
102  case kAddC:
103  case kAddI:
104  case kSubF:
105  case kSubC:
106  case kSubI:
107  case kAndI:
108  case kOrI:
109  case kXorI:
110  case kShrS:
111  case kShrU:
112  case kShlI:
113  assert(x != -1u && y != -1u && !v && !o);
114  children.e0 = x;
115  children.e1 = y;
116  break;
117  case kBinary:
118  case kReduce:
119  assert(x != -1u && y != -1u && !v && o);
120  children.e0 = x;
121  children.e1 = y;
122  break;
123  }
124 }
125 
126 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
127  : bits(n, false), exp(e) {
128  bits.set(b);
129 }
130 
131 LatPoint::LatPoint(const BitVector &b, unsigned e) : bits(b), exp(e) {}
132 
133 //===----------------------------------------------------------------------===//
134 // Lattice methods.
135 //===----------------------------------------------------------------------===//
136 
137 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
138  Operation *op) {
139  unsigned e = tensorExps.size();
140  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
141  return e;
142 }
143 
144 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
145  assert(t < numTensors && i < numLoops);
146  unsigned p = latPoints.size();
147  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
148  return p;
149 }
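// For illustration: addLat encodes the (tensor, loop) pair as a single bit
// b = numTensors * i + t in a bitvector of numLoops * numTensors bits; the
// tensor() and index() helpers in Merger.h translate a bit back to its tensor
// and loop. For example, with numTensors = 3 and numLoops = 2, addLat(/*t=*/1,
// /*i=*/1, e) creates a lattice point whose bits have only bit 3 * 1 + 1 = 4
// set.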
150 
151 unsigned Merger::addSet() {
152  unsigned s = latSets.size();
153  latSets.emplace_back();
154  return s;
155 }
156 
157 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
158  Operation *op) {
159  unsigned p = latPoints.size();
160  BitVector nb = BitVector(latPoints[p0].bits);
161  nb |= latPoints[p1].bits;
162  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
163  latPoints.push_back(LatPoint(nb, e));
164  return p;
165 }
166 
167 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
168  unsigned s = addSet();
169  for (unsigned p0 : latSets[s0])
170  for (unsigned p1 : latSets[s1])
171  latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
172  return s;
173 }
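// For illustration: takeConj forms a cartesian product. If latSets[s0] holds
// the points {p0, p1} and latSets[s1] holds {q0, q1}, the resulting set holds
// the four conjunctions p0&q0, p0&q1, p1&q0, p1&q1, where each conjunction ORs
// the loop-index bits of its two points and combines their expressions under
// the given operator kind (see conjLatPoint above).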
174 
175 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
176  unsigned s = takeConj(kind, s0, s1, op);
177  // Followed by all in s0.
178  for (unsigned p : latSets[s0])
179  latSets[s].push_back(p);
180  // Map binary 0-y to unary -y.
181  // TODO: move this if-else logic into buildLattices
182  if (kind == kSubF)
183  s1 = mapSet(kNegF, s1);
184  else if (kind == kSubC)
185  s1 = mapSet(kNegC, s1);
186  else if (kind == kSubI)
187  s1 = mapSet(kNegI, s1);
188  // Followed by all in s1.
189  for (unsigned p : latSets[s1])
190  latSets[s].push_back(p);
191  return s;
192 }
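// For illustration: for an addition x(i) + y(i) with singleton operand sets
// { {x_i} } and { {y_i} }, takeDisj produces the familiar three-way split
//   { {x_i, y_i} : x+y,  {x_i} : x,  {y_i} : y }
// i.e. all conjunctions first, then every point of s0, then every point of s1
// (with s1 rewritten into a unary negation for the subtraction cases above).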
193 
194 unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
195  bool includeLeft, Kind ltrans, Operation *opleft,
196  bool includeRight, Kind rtrans, Operation *opright) {
197  unsigned s = takeConj(kind, s0, s1, orig);
198  // Left Region.
199  if (includeLeft) {
200  if (opleft)
201  s0 = mapSet(ltrans, s0, Value(), opleft);
202  for (unsigned p : latSets[s0])
203  latSets[s].push_back(p);
204  }
205  // Right Region.
206  if (includeRight) {
207  if (opright)
208  s1 = mapSet(rtrans, s1, Value(), opright);
209  for (unsigned p : latSets[s1])
210  latSets[s].push_back(p);
211  }
212  return s;
213 }
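// For illustration: takeCombi generalizes takeDisj for custom semi-ring
// operators. For x BINARY y it always emits the overlap conjunctions, and then
// optionally the points of s0 (the "left" region) and of s1 (the "right"
// region); when a left/right region is present, those points are first
// re-tagged via mapSet with the given branch kind (e.g. kBinaryBranch) and the
// region's yield operation.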
214 
215 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
216  assert(kAbsF <= kind && kind <= kSelect);
217  unsigned s = addSet();
218  for (unsigned p : latSets[s0]) {
219  unsigned e = addExp(kind, latPoints[p].exp, v, op);
220  latPoints.push_back(LatPoint(latPoints[p].bits, e));
221  latSets[s].push_back(latPoints.size() - 1);
222  }
223  return s;
224 }
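// For illustration: mapSet leaves the bits of every point unchanged and only
// wraps each point's expression in the given unary kind. Mapping kNegF over
// { {y_i} : y } yields { {y_i} : -y }, which is how takeDisj above turns the
// "only y present" case of a subtraction x - y into a negation.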
225 
226 unsigned Merger::optimizeSet(unsigned s0) {
227  unsigned s = addSet();
228  assert(!latSets[s0].empty());
229  unsigned p0 = latSets[s0][0];
230  for (unsigned p1 : latSets[s0]) {
231  bool add = true;
232  if (p0 != p1) {
233  // Is this a straightforward copy?
234  unsigned e = latPoints[p1].exp;
235  if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
236  continue;
237  // Conjunction already covered?
238  for (unsigned p2 : latSets[s]) {
239  assert(!latGT(p1, p2)); // Lj => Li would be bad
240  if (onlyDenseDiff(p2, p1)) {
241  add = false;
242  break;
243  }
244  }
245  assert(!add || latGT(p0, p1));
246  }
247  if (add)
248  latSets[s].push_back(p1);
249  }
250  for (unsigned p : latSets[s])
251  latPoints[p].simple = simplifyCond(s, p);
252  return s;
253 }
254 
255 BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
256  // First determine if this lattice point is a *singleton*, i.e.,
257  // the last point in a lattice, no other is less than this one.
258  bool isSingleton = true;
259  for (unsigned p1 : latSets[s0]) {
260  if (p0 != p1 && latGT(p0, p1)) {
261  isSingleton = false;
262  break;
263  }
264  }
265 
266  BitVector simple = latPoints[p0].bits;
267  bool reset = isSingleton && hasAnySparse(simple);
268  unsigned be = simple.size();
269  unsigned offset = 0; // relative to the end
270  if (!reset)
271  // Starts resetting from a dense dimension, so that the first bit (if kept)
272  // is not undefined dimension type.
273  for (unsigned b = 0; b < be; b++) {
274  if (simple[b] && isDenseDLT(getDimLevelType(b))) {
275  offset = be - b - 1; // relative to the end
276  break;
277  }
278  }
279 
 280  // Now apply the two basic rules. We also iterate the bits in reverse order
 281  // to always keep the rightmost bit (which could possibly be a synthetic tensor).
282  for (unsigned b = be - 1 - offset, i = 0; i < be;
283  b = b == 0 ? be - 1 : b - 1, i++) {
 284  if (simple[b] && (!isCompressedDLT(getDimLevelType(b)) &&
 285  !isSingletonDLT(getDimLevelType(b)))) {
286  if (reset)
287  simple.reset(b);
288  reset = true;
289  }
290  }
291  return simple;
292 }
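// For illustration: for x(i) = a(i) * b(i) with a dense and b compressed, the
// lattice set holds the single conjunction point {a_i, b_i}. That point is a
// singleton and contains a sparse bit, so `reset` starts out true and the
// dense bit a_i is cleared, leaving only the sparse condition {b_i}; later
// code generation then only needs to test b's compressed level while a is
// addressed by the loop index. In a disjunction such as a(i) + b(i) the top
// point is not a singleton, so its dense bit is kept and both conditions
// remain.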
293 
294 bool Merger::latGT(unsigned i, unsigned j) const {
295  const BitVector &bitsi = latPoints[i].bits;
296  const BitVector &bitsj = latPoints[j].bits;
297  assert(bitsi.size() == bitsj.size());
298  if (bitsi.count() > bitsj.count()) {
299  for (unsigned b = 0, be = bitsj.size(); b < be; b++)
300  if (bitsj[b] && !bitsi[b])
301  return false;
302  return true;
303  }
304  return false;
305 }
306 
307 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
308  BitVector tmp = latPoints[j].bits;
309  tmp ^= latPoints[i].bits;
310  return !hasAnySparse(tmp);
311 }
312 
313 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
314  switch (tensorExps[e].kind) {
315  // Leaf.
316  case kTensor:
317  return tensorExps[e].tensor == t;
318  case kInvariant:
319  case kIndex:
320  return false;
321  // Unary operations.
322  case kAbsF:
323  case kAbsC:
324  case kAbsI:
325  case kCeilF:
326  case kFloorF:
327  case kSqrtF:
328  case kSqrtC:
329  case kExpm1F:
330  case kExpm1C:
331  case kLog1pF:
332  case kLog1pC:
333  case kSinF:
334  case kSinC:
335  case kTanhF:
336  case kTanhC:
337  case kNegF:
338  case kNegC:
339  case kNegI:
340  case kTruncF:
341  case kExtF:
342  case kCastFS:
343  case kCastFU:
344  case kCastSF:
345  case kCastUF:
346  case kCastS:
347  case kCastU:
348  case kCastIdx:
349  case kTruncI:
350  case kCIm:
351  case kCRe:
352  case kBitCast:
353  return isSingleCondition(t, tensorExps[e].children.e0);
354  case kBinaryBranch:
355  case kUnary:
356  case kSelect:
357  return false;
358  // Binary operations.
359  case kDivF: // note: x / c only
360  case kDivC:
361  case kDivS:
362  case kDivU:
363  assert(!maybeZero(tensorExps[e].children.e1));
364  return isSingleCondition(t, tensorExps[e].children.e0);
365  case kShrS: // note: x >> inv only
366  case kShrU:
367  case kShlI:
368  assert(isInvariant(tensorExps[e].children.e1));
369  return isSingleCondition(t, tensorExps[e].children.e0);
370  case kMulF:
371  case kMulC:
372  case kMulI:
373  case kAndI:
374  if (isSingleCondition(t, tensorExps[e].children.e0))
375  return isSingleCondition(t, tensorExps[e].children.e1) ||
376  isInvariant(tensorExps[e].children.e1);
377  if (isSingleCondition(t, tensorExps[e].children.e1))
378  return isInvariant(tensorExps[e].children.e0);
379  return false;
380  case kAddF:
381  case kAddC:
382  case kAddI:
383  return isSingleCondition(t, tensorExps[e].children.e0) &&
384  isSingleCondition(t, tensorExps[e].children.e1);
385  case kSubF:
386  case kSubC:
387  case kSubI:
388  case kOrI:
389  case kXorI:
390  case kBinary:
391  case kReduce:
392  return false;
393  }
394  llvm_unreachable("unexpected kind");
395 }
396 
397 bool Merger::hasAnySparse(const BitVector &bits) const {
398  for (unsigned b = 0, be = bits.size(); b < be; b++)
 399  if (bits[b] && (isCompressedDLT(getDimLevelType(b)) ||
 400  isSingletonDLT(getDimLevelType(b))))
401  return true;
402  return false;
403 }
404 
405 #ifndef NDEBUG
406 
407 //===----------------------------------------------------------------------===//
408 // Print methods (for debugging).
409 //===----------------------------------------------------------------------===//
410 
411 static const char *kindToOpSymbol(Kind kind) {
412  switch (kind) {
413  // Leaf.
414  case kTensor:
415  return "tensor";
416  case kInvariant:
417  return "invariant";
418  case kIndex:
419  return "index";
420  // Unary operations.
421  case kAbsF:
422  case kAbsC:
423  case kAbsI:
424  return "abs";
425  case kCeilF:
426  return "ceil";
427  case kFloorF:
428  return "floor";
429  case kSqrtF:
430  case kSqrtC:
431  return "sqrt";
432  case kExpm1F:
433  case kExpm1C:
434  return "expm1";
435  case kLog1pF:
436  case kLog1pC:
437  return "log1p";
438  case kSinF:
439  case kSinC:
440  return "sin";
441  case kTanhF:
442  case kTanhC:
443  return "tanh";
444  case kNegF:
445  case kNegC:
446  case kNegI:
447  return "-";
448  case kTruncF:
449  case kExtF:
450  case kCastFS:
451  case kCastFU:
452  case kCastSF:
453  case kCastUF:
454  case kCastS:
455  case kCastU:
456  case kCastIdx:
457  case kTruncI:
458  case kCIm:
459  return "complex.im";
460  case kCRe:
461  return "complex.re";
462  case kBitCast:
463  return "cast";
464  case kBinaryBranch:
465  return "binary_branch";
466  case kUnary:
467  return "unary";
468  case kSelect:
469  return "select";
470  // Binary operations.
471  case kMulF:
472  case kMulC:
473  case kMulI:
474  return "*";
475  case kDivF:
476  case kDivC:
477  case kDivS:
478  case kDivU:
479  return "/";
480  case kAddF:
481  case kAddC:
482  case kAddI:
483  return "+";
484  case kSubF:
485  case kSubC:
486  case kSubI:
487  return "-";
488  case kAndI:
489  return "&";
490  case kOrI:
491  return "|";
492  case kXorI:
493  return "^";
494  case kShrS:
495  return "a>>";
496  case kShrU:
497  return ">>";
498  case kShlI:
499  return "<<";
500  case kBinary:
501  return "binary";
502  case kReduce:
503  return "reduce";
504  }
505  llvm_unreachable("unexpected kind for symbol");
506 }
507 
508 void Merger::dumpExp(unsigned e) const {
509  switch (tensorExps[e].kind) {
510  // Leaf.
511  case kTensor:
512  if (tensorExps[e].tensor == syntheticTensor)
513  llvm::dbgs() << "synthetic_";
514  else if (tensorExps[e].tensor == outTensor)
515  llvm::dbgs() << "output_";
516  llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
517  break;
518  case kInvariant:
519  llvm::dbgs() << "invariant";
520  break;
521  case kIndex:
522  llvm::dbgs() << "index_" << tensorExps[e].index;
523  break;
524  // Unary operations.
525  case kAbsF:
526  case kAbsC:
527  case kAbsI:
528  case kCeilF:
529  case kFloorF:
530  case kSqrtF:
531  case kSqrtC:
532  case kExpm1F:
533  case kExpm1C:
534  case kLog1pF:
535  case kLog1pC:
536  case kSinF:
537  case kSinC:
538  case kTanhF:
539  case kTanhC:
540  case kNegF:
541  case kNegC:
542  case kNegI:
543  case kTruncF:
544  case kExtF:
545  case kCastFS:
546  case kCastFU:
547  case kCastSF:
548  case kCastUF:
549  case kCastS:
550  case kCastU:
551  case kCastIdx:
552  case kTruncI:
553  case kCIm:
554  case kCRe:
555  case kBitCast:
556  case kBinaryBranch:
557  case kUnary:
558  case kSelect:
559  llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
560  dumpExp(tensorExps[e].children.e0);
561  break;
562  // Binary operations.
563  case kMulF:
564  case kMulC:
565  case kMulI:
566  case kDivF:
567  case kDivC:
568  case kDivS:
569  case kDivU:
570  case kAddF:
571  case kAddC:
572  case kAddI:
573  case kSubF:
574  case kSubC:
575  case kSubI:
576  case kAndI:
577  case kOrI:
578  case kXorI:
579  case kShrS:
580  case kShrU:
581  case kShlI:
582  case kBinary:
583  case kReduce:
584  llvm::dbgs() << "(";
585  dumpExp(tensorExps[e].children.e0);
586  llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
587  dumpExp(tensorExps[e].children.e1);
588  llvm::dbgs() << ")";
589  }
590 }
591 
592 void Merger::dumpLat(unsigned p) const {
593  llvm::dbgs() << "lat(";
594  dumpBits(latPoints[p].bits);
595  llvm::dbgs() << " :";
596  dumpBits(latPoints[p].simple);
597  llvm::dbgs() << " : ";
598  dumpExp(latPoints[p].exp);
599  llvm::dbgs() << " )\n";
600 }
601 
602 void Merger::dumpSet(unsigned s) const {
603  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
604  for (unsigned p : latSets[s]) {
605  llvm::dbgs() << " ";
606  dumpLat(p);
607  }
608  llvm::dbgs() << "}\n";
609 }
610 
611 void Merger::dumpBits(const BitVector &bits) const {
612  for (unsigned b = 0, be = bits.size(); b < be; b++) {
613  if (bits[b]) {
614  unsigned t = tensor(b);
615  unsigned i = index(b);
616  DimLevelType dlt = dimTypes[t][i];
617  llvm::dbgs() << " i_" << t << "_" << i << "_";
618  if (isDenseDLT(dlt))
619  llvm::dbgs() << "D";
620  else if (isCompressedDLT(dlt))
621  llvm::dbgs() << "C";
622  else if (isSingletonDLT(dlt))
623  llvm::dbgs() << "S";
624  else if (isUndefDLT(dlt))
625  llvm::dbgs() << "U";
626  llvm::dbgs() << "[O=" << isOrderedDLT(dlt) << ",U=" << isUniqueDLT(dlt)
627  << "]";
628  }
629  }
630 }
631 
632 #endif // NDEBUG
633 
634 //===----------------------------------------------------------------------===//
635 // Builder methods.
636 //===----------------------------------------------------------------------===//
637 
638 unsigned Merger::buildLattices(unsigned e, unsigned i) {
639  Kind kind = tensorExps[e].kind;
640  switch (kind) {
641  // Leaf.
642  case kTensor:
643  case kInvariant:
644  case kIndex: {
645  // Either the index is really used in the tensor expression, or it is
646  // set to the undefined index in that dimension. An invariant expression,
647  // a proper index value, and a truly dynamic sparse output tensor are set
648  // to a synthetic tensor with undefined indices only to ensure the
649  // iteration space is not skipped as a result of their contents.
650  unsigned s = addSet();
651  unsigned t = syntheticTensor;
652  if (kind == kTensor) {
653  t = tensorExps[e].tensor;
654  if (hasSparseOut && t == outTensor)
655  t = syntheticTensor;
656  }
657  latSets[s].push_back(addLat(t, i, e));
658  return s;
659  }
660  // Unary operations.
661  case kAbsF:
662  case kAbsC:
663  case kAbsI:
664  case kCeilF:
665  case kFloorF:
666  case kSqrtF:
667  case kSqrtC:
668  case kExpm1F:
669  case kExpm1C:
670  case kLog1pF:
671  case kLog1pC:
672  case kSinF:
673  case kSinC:
674  case kTanhF:
675  case kTanhC:
676  case kNegF:
677  case kNegC:
678  case kNegI:
679  case kTruncF:
680  case kExtF:
681  case kCastFS:
682  case kCastFU:
683  case kCastSF:
684  case kCastUF:
685  case kCastS:
686  case kCastU:
687  case kCastIdx:
688  case kTruncI:
689  case kCIm:
690  case kCRe:
691  case kBitCast:
692  // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
693  // lattice set of the operand through the operator into a new set.
694  //
695  // -y|!y | y |
696  // --+---+---+
697  // | 0 |-y |
698  return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
699  tensorExps[e].val);
700  case kBinaryBranch:
701  case kSelect:
702  // The left or right half of a binary operation which has already
703  // been split into separate operations for each region.
704  return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
705  tensorExps[e].op);
706  case kUnary:
707  // A custom unary operation.
708  //
709  // op y| !y | y |
710  // ----+----------+------------+
711  // | absent() | present(y) |
712  {
713  unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
714  UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
715  Region &absentRegion = unop.getAbsentRegion();
716 
717  if (absentRegion.empty()) {
718  // Simple mapping over existing values.
719  return mapSet(kind, child0, Value(), unop);
720  } // Use a disjunction with `unop` on the left and the absent value as an
721  // invariant on the right.
722  Block &absentBlock = absentRegion.front();
723  YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
724  Value absentVal = absentYield.getResult();
725  unsigned rhs = addExp(kInvariant, absentVal);
726  return takeDisj(kind, child0, buildLattices(rhs, i), unop);
727  }
728  // Binary operations.
729  case kMulF:
730  case kMulC:
731  case kMulI:
732  case kAndI:
733  // A multiplicative operation only needs to be performed
734  // for the conjunction of sparse iteration spaces.
735  //
736  // x*y|!y | y |
737  // ---+---+---+
738  // !x | 0 | 0 |
739  // x | 0 |x*y|
740  //
741  // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
742  return takeConj(kind, // take binary conjunction
743  buildLattices(tensorExps[e].children.e0, i),
744  buildLattices(tensorExps[e].children.e1, i));
745  case kDivF:
746  case kDivC:
747  case kDivS:
748  case kDivU:
749  // A division is tricky, since 0/0, 0/c, c/0 all have
750  // specific outcomes for floating-point and integers.
751  // Thus, we need to traverse the full iteration space.
752  //
753  // x/y|!y | y |
754  // ---+---+---+
755  // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
756  // x |x/0|x/y| INT: x/0=exception for any x
757  //
758  // TODO: for now we "fixed" this by only accepting x/c cases
759  // during expression building, so that the conjunction
 760  // rule applies (viz. x/c = x*(1/c) as far as lattice
761  // construction is concerned).
762  assert(!maybeZero(tensorExps[e].children.e1));
763  return takeConj(kind, // take binary conjunction
764  buildLattices(tensorExps[e].children.e0, i),
765  buildLattices(tensorExps[e].children.e1, i));
766  case kAddF:
767  case kAddC:
768  case kAddI:
769  case kSubF:
770  case kSubC:
771  case kSubI:
772  case kOrI:
773  case kXorI:
774  // An additive operation needs to be performed
775  // for the disjunction of sparse iteration spaces.
776  //
777  // x+y|!y | y | x-y|!y | y |
778  // ---+---+---+ ---+---+---+
779  // !x | 0 | y | !x | 0 |-y |
780  // x | x |x+y| x | x |x-y|
781  return takeDisj(kind, // take binary disjunction
782  buildLattices(tensorExps[e].children.e0, i),
783  buildLattices(tensorExps[e].children.e1, i));
784  case kShrS:
785  case kShrU:
786  case kShlI:
787  // A shift operation by an invariant amount (viz. tensor expressions
 788  // can only occur on the left-hand side of the operator) can be handled
 789  // with the conjunction rule.
790  assert(isInvariant(tensorExps[e].children.e1));
791  return takeConj(kind, // take binary conjunction
792  buildLattices(tensorExps[e].children.e0, i),
793  buildLattices(tensorExps[e].children.e1, i));
794  case kBinary:
795  // A custom binary operation.
796  //
797  // x op y| !y | y |
798  // ------+---------+--------------+
799  // !x | empty | right(y) |
800  // x | left(x) | overlap(x,y) |
801  {
802  unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
803  unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
804  BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
805  Region &leftRegion = binop.getLeftRegion();
806  Region &rightRegion = binop.getRightRegion();
807  // Left Region.
808  Operation *leftYield = nullptr;
809  if (!leftRegion.empty()) {
810  Block &leftBlock = leftRegion.front();
811  leftYield = leftBlock.getTerminator();
812  }
813  // Right Region.
814  Operation *rightYield = nullptr;
815  if (!rightRegion.empty()) {
816  Block &rightBlock = rightRegion.front();
817  rightYield = rightBlock.getTerminator();
818  }
819  bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
820  bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
821  return takeCombi(kBinary, child0, child1, binop, includeLeft,
822  kBinaryBranch, leftYield, includeRight, kBinaryBranch,
823  rightYield);
824  }
825  case kReduce:
826  // A custom reduce operation.
827  return takeConj(kind, buildLattices(tensorExps[e].children.e0, i),
828  buildLattices(tensorExps[e].children.e1, i),
829  tensorExps[e].op);
830  }
831  llvm_unreachable("unexpected expression kind");
832 }
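// For illustration, building lattices bottom-up for a(i) = b(i) * c(i) + d(i):
//   - each leaf b, c, d yields a singleton set, e.g. { {b_i} : b };
//   - the multiplication takes the conjunction, giving { {b_i, c_i} : b*c };
//   - the addition takes the disjunction with d, giving
//       { {b_i, c_i, d_i} : b*c+d,  {b_i, c_i} : b*c,  {d_i} : d }
// which enumerates exactly the cases a sparse code generator must handle:
// both operands present, only b*c present, or only d present.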
 833 
 834 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
835  // Build the linalg semantics backward from yield.
836  Operation *yield = op.getRegion().front().getTerminator();
837  assert(isa<linalg::YieldOp>(yield));
838  return buildTensorExp(op, yield->getOperand(0));
839 }
840 
841 /// Only returns false if we are certain this is a nonzero.
842 bool Merger::maybeZero(unsigned e) const {
843  if (tensorExps[e].kind == kInvariant) {
844  if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
845  ArrayAttr arrayAttr = c.getValue();
846  return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
847  arrayAttr[1].cast<FloatAttr>().getValue().isZero();
848  }
849  if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
850  return c.value() == 0;
851  if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
852  return c.value().isZero();
853  }
854  return true;
855 }
856 
857 bool Merger::isInvariant(unsigned e) const {
858  return tensorExps[e].kind == kInvariant;
859 }
860 
861 Type Merger::inferType(unsigned e, Value src) {
862  // Obtain the destination type from the cast node.
863  Type dtp = tensorExps[e].val.getType();
864  // Inspect source type. For vector types, apply the same
865  // vectorization to the destination type.
866  if (auto vtp = src.getType().dyn_cast<VectorType>())
867  return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
868  return dtp;
869 }
870 
 871 /// Ensures that the sparse compiler can generate code for the expression.
872 static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
873  // Arguments are always admissible.
874  if (auto arg = v.dyn_cast<BlockArgument>())
875  return true;
876  // Accept index anywhere.
877  Operation *def = v.getDefiningOp();
878  if (isa<linalg::IndexOp>(def))
879  return true;
880  // Operation defined outside branch.
881  if (def->getBlock() != block)
882  return def->getBlock() != op->getBlock(); // invariant?
883  // Operation defined within branch. Anything is accepted,
884  // as long as all subexpressions are admissible.
885  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
886  if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
887  return false;
888  return true;
889 }
890 
 891 /// Ensures that the sparse compiler can generate code for the branch.
892 static bool isAdmissibleBranch(Operation *op, Region &region) {
893  if (region.empty())
894  return true;
895  // Build the semi-ring branch semantics backward from yield.
896  Operation *yield = region.front().getTerminator();
897  assert(isa<YieldOp>(yield));
898  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
899 }
900 
901 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
902  if (auto arg = v.dyn_cast<BlockArgument>()) {
903  unsigned argN = arg.getArgNumber();
904  // Any argument of the generic op that is not marked as a scalar
905  // argument is considered a tensor, indexed by the implicit loop
906  // bounds. This includes rank-0 tensor arguments.
907  if (arg.getOwner()->getParentOp() == op) {
908  OpOperand &t = op->getOpOperand(argN);
909  if (!op.isScalar(&t))
910  return addExp(kTensor, argN);
911  v = t.get(); // get scalar value
912  }
913  // Any other argument (marked as scalar argument for the generic op
914  // or belonging to an enveloping op) is considered invariant.
915  return addExp(kInvariant, v);
916  }
917  // Something defined outside is invariant.
918  Operation *def = v.getDefiningOp();
919  if (def->getBlock() != &op.getRegion().front())
920  return addExp(kInvariant, v);
921  // Construct index operations.
922  if (def->getNumOperands() == 0) {
923  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
924  return addExp(kIndex, indexOp.getDim());
925  }
926  // Construct unary operations if subexpression can be built.
927  if (def->getNumOperands() == 1) {
928  auto x = buildTensorExp(op, def->getOperand(0));
929  if (x.has_value()) {
930  unsigned e = x.value();
931  if (isa<math::AbsFOp>(def))
932  return addExp(kAbsF, e);
933  if (isa<complex::AbsOp>(def))
934  return addExp(kAbsC, e);
935  if (isa<math::AbsIOp>(def))
936  return addExp(kAbsI, e);
937  if (isa<math::CeilOp>(def))
938  return addExp(kCeilF, e);
939  if (isa<math::FloorOp>(def))
940  return addExp(kFloorF, e);
941  if (isa<math::SqrtOp>(def))
942  return addExp(kSqrtF, e);
943  if (isa<complex::SqrtOp>(def))
944  return addExp(kSqrtC, e);
945  if (isa<math::ExpM1Op>(def))
946  return addExp(kExpm1F, e);
947  if (isa<complex::Expm1Op>(def))
948  return addExp(kExpm1C, e);
949  if (isa<math::Log1pOp>(def))
950  return addExp(kLog1pF, e);
951  if (isa<complex::Log1pOp>(def))
952  return addExp(kLog1pC, e);
953  if (isa<math::SinOp>(def))
954  return addExp(kSinF, e);
955  if (isa<complex::SinOp>(def))
956  return addExp(kSinC, e);
957  if (isa<math::TanhOp>(def))
958  return addExp(kTanhF, e);
959  if (isa<complex::TanhOp>(def))
960  return addExp(kTanhC, e);
961  if (isa<arith::NegFOp>(def))
962  return addExp(kNegF, e); // no negi in std
963  if (isa<complex::NegOp>(def))
964  return addExp(kNegC, e);
965  if (isa<arith::TruncFOp>(def))
966  return addExp(kTruncF, e, v);
967  if (isa<arith::ExtFOp>(def))
968  return addExp(kExtF, e, v);
969  if (isa<arith::FPToSIOp>(def))
970  return addExp(kCastFS, e, v);
971  if (isa<arith::FPToUIOp>(def))
972  return addExp(kCastFU, e, v);
973  if (isa<arith::SIToFPOp>(def))
974  return addExp(kCastSF, e, v);
975  if (isa<arith::UIToFPOp>(def))
976  return addExp(kCastUF, e, v);
977  if (isa<arith::ExtSIOp>(def))
978  return addExp(kCastS, e, v);
979  if (isa<arith::ExtUIOp>(def))
980  return addExp(kCastU, e, v);
981  if (isa<arith::IndexCastOp>(def))
982  return addExp(kCastIdx, e, v);
983  if (isa<arith::TruncIOp>(def))
984  return addExp(kTruncI, e, v);
985  if (isa<complex::ImOp>(def))
986  return addExp(kCIm, e);
987  if (isa<complex::ReOp>(def))
988  return addExp(kCRe, e);
989  if (isa<arith::BitcastOp>(def))
990  return addExp(kBitCast, e, v);
991  if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
992  if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
993  isAdmissibleBranch(unop, unop.getAbsentRegion()))
994  return addExp(kUnary, e, Value(), def);
995  }
996  if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
997  if (isAdmissibleBranch(selop, selop.getRegion()))
998  return addExp(kSelect, e, Value(), def);
999  }
1000  }
1001  }
1002  // Construct binary operations if subexpressions can be built.
1003  // See buildLattices() for an explanation of rejecting certain
1004  // division and shift operations.
1005  if (def->getNumOperands() == 2) {
1006  auto x = buildTensorExp(op, def->getOperand(0));
1007  auto y = buildTensorExp(op, def->getOperand(1));
1008  if (x.has_value() && y.has_value()) {
1009  unsigned e0 = x.value();
1010  unsigned e1 = y.value();
1011  if (isa<arith::MulFOp>(def))
1012  return addExp(kMulF, e0, e1);
1013  if (isa<complex::MulOp>(def))
1014  return addExp(kMulC, e0, e1);
1015  if (isa<arith::MulIOp>(def))
1016  return addExp(kMulI, e0, e1);
1017  if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1018  return addExp(kDivF, e0, e1);
1019  if (isa<complex::DivOp>(def) && !maybeZero(e1))
1020  return addExp(kDivC, e0, e1);
1021  if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1022  return addExp(kDivS, e0, e1);
1023  if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1024  return addExp(kDivU, e0, e1);
1025  if (isa<arith::AddFOp>(def))
1026  return addExp(kAddF, e0, e1);
1027  if (isa<complex::AddOp>(def))
1028  return addExp(kAddC, e0, e1);
1029  if (isa<arith::AddIOp>(def))
1030  return addExp(kAddI, e0, e1);
1031  if (isa<arith::SubFOp>(def))
1032  return addExp(kSubF, e0, e1);
1033  if (isa<complex::SubOp>(def))
1034  return addExp(kSubC, e0, e1);
1035  if (isa<arith::SubIOp>(def))
1036  return addExp(kSubI, e0, e1);
1037  if (isa<arith::AndIOp>(def))
1038  return addExp(kAndI, e0, e1);
1039  if (isa<arith::OrIOp>(def))
1040  return addExp(kOrI, e0, e1);
1041  if (isa<arith::XOrIOp>(def))
1042  return addExp(kXorI, e0, e1);
1043  if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1044  return addExp(kShrS, e0, e1);
1045  if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1046  return addExp(kShrU, e0, e1);
1047  if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1048  return addExp(kShlI, e0, e1);
1049  if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1050  if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1051  (binop.getLeftIdentity() ||
1052  isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1053  (binop.getRightIdentity() ||
1054  isAdmissibleBranch(binop, binop.getRightRegion())))
1055  return addExp(kBinary, e0, e1, Value(), def);
1056  }
1057  }
1058  }
1059  // Construct ternary operations if subexpressions can be built.
1060  if (def->getNumOperands() == 3) {
1061  auto x = buildTensorExp(op, def->getOperand(0));
1062  auto y = buildTensorExp(op, def->getOperand(1));
1063  auto z = buildTensorExp(op, def->getOperand(2));
1064  if (x.has_value() && y.has_value() && z.has_value()) {
1065  unsigned e0 = x.value();
1066  unsigned e1 = y.value();
1067  if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1068  if (isAdmissibleBranch(redop, redop.getRegion()))
1069  return addExp(kReduce, e0, e1, Value(), def);
1070  }
1071  }
1072  }
1073  // Cannot build.
1074  return None;
1075 }
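// For illustration: for a linalg.generic whose body is
//   %0 = arith.mulf %a, %b : f64
//   %1 = arith.addf %0, %c : f64
//   linalg.yield %1 : f64
// buildTensorExpFromLinalg starts from the yield operand and recursively
// builds kAddF(kMulF(kTensor 0, kTensor 1), kTensor 2), with the block
// arguments %a, %b, %c becoming kTensor leaves. The traversal returns None as
// soon as a subexpression cannot be represented, in which case the kernel is
// not handled here.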
1076 
1077 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1078  ValueRange vals) {
 1079  // Make a clone of the given region.
1080  Region tmpRegion;
1081  BlockAndValueMapping mapper;
1082  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1083  Block &clonedBlock = tmpRegion.front();
1084  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1085  // Merge cloned block and return yield value.
1086  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1087  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
1088  Value val = clonedYield.getResult();
1089  rewriter.eraseOp(clonedYield);
1090  rewriter.eraseOp(placeholder);
1091  return val;
1092 }
 1093 
 1094 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1095  Operation *op, Value v0) {
1096  if (!v0)
1097  // Empty input value must be propagated.
1098  return Value();
1099  UnaryOp unop = cast<UnaryOp>(op);
1100  Region &presentRegion = unop.getPresentRegion();
1101  if (presentRegion.empty())
1102  // Uninitialized Value() will be interpreted as missing data in the
1103  // output.
1104  return Value();
1105  return insertYieldOp(rewriter, loc, presentRegion, {v0});
1106 }
 1107 
 1108 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1109  Operation *op, Value v0, Value v1) {
1110  if (!v0 || !v1)
1111  // Empty input values must be propagated.
1112  return Value();
1113  BinaryOp binop = cast<BinaryOp>(op);
1114  Region &overlapRegion = binop.getOverlapRegion();
1115  if (overlapRegion.empty())
1116  // Uninitialized Value() will be interpreted as missing data in the
1117  // output.
1118  return Value();
1119  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1120 }
1121 
1122 Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
1123  Value v0, Value v1) {
1124  switch (tensorExps[e].kind) {
1125  // Leaf.
1126  case kTensor:
1127  case kInvariant:
1128  case kIndex:
1129  llvm_unreachable("unexpected non-op");
1130  // Unary operations.
1131  case kAbsF:
1132  return rewriter.create<math::AbsFOp>(loc, v0);
1133  case kAbsC: {
1134  auto type = v0.getType().cast<ComplexType>();
1135  auto eltType = type.getElementType().cast<FloatType>();
1136  return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1137  }
1138  case kAbsI:
1139  return rewriter.create<math::AbsIOp>(loc, v0);
1140  case kCeilF:
1141  return rewriter.create<math::CeilOp>(loc, v0);
1142  case kFloorF:
1143  return rewriter.create<math::FloorOp>(loc, v0);
1144  case kSqrtF:
1145  return rewriter.create<math::SqrtOp>(loc, v0);
1146  case kSqrtC:
1147  return rewriter.create<complex::SqrtOp>(loc, v0);
1148  case kExpm1F:
1149  return rewriter.create<math::ExpM1Op>(loc, v0);
1150  case kExpm1C:
1151  return rewriter.create<complex::Expm1Op>(loc, v0);
1152  case kLog1pF:
1153  return rewriter.create<math::Log1pOp>(loc, v0);
1154  case kLog1pC:
1155  return rewriter.create<complex::Log1pOp>(loc, v0);
1156  case kSinF:
1157  return rewriter.create<math::SinOp>(loc, v0);
1158  case kSinC:
1159  return rewriter.create<complex::SinOp>(loc, v0);
1160  case kTanhF:
1161  return rewriter.create<math::TanhOp>(loc, v0);
1162  case kTanhC:
1163  return rewriter.create<complex::TanhOp>(loc, v0);
1164  case kNegF:
1165  return rewriter.create<arith::NegFOp>(loc, v0);
1166  case kNegC:
1167  return rewriter.create<complex::NegOp>(loc, v0);
1168  case kNegI: // no negi in std
1169  return rewriter.create<arith::SubIOp>(
1170  loc,
1171  rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1172  rewriter.getZeroAttr(v0.getType())),
1173  v0);
1174  case kTruncF:
1175  return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1176  case kExtF:
1177  return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1178  case kCastFS:
1179  return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1180  case kCastFU:
1181  return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1182  case kCastSF:
1183  return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1184  case kCastUF:
1185  return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1186  case kCastS:
1187  return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1188  case kCastU:
1189  return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1190  case kCastIdx:
1191  return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1192  case kTruncI:
1193  return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1194  case kCIm: {
1195  auto type = v0.getType().cast<ComplexType>();
1196  auto eltType = type.getElementType().cast<FloatType>();
1197  return rewriter.create<complex::ImOp>(loc, eltType, v0);
1198  }
1199  case kCRe: {
1200  auto type = v0.getType().cast<ComplexType>();
1201  auto eltType = type.getElementType().cast<FloatType>();
1202  return rewriter.create<complex::ReOp>(loc, eltType, v0);
1203  }
1204  case kBitCast:
1205  return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1206  // Binary operations.
1207  case kMulF:
1208  return rewriter.create<arith::MulFOp>(loc, v0, v1);
1209  case kMulC:
1210  return rewriter.create<complex::MulOp>(loc, v0, v1);
1211  case kMulI:
1212  return rewriter.create<arith::MulIOp>(loc, v0, v1);
1213  case kDivF:
1214  return rewriter.create<arith::DivFOp>(loc, v0, v1);
1215  case kDivC:
1216  return rewriter.create<complex::DivOp>(loc, v0, v1);
1217  case kDivS:
1218  return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1219  case kDivU:
1220  return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1221  case kAddF:
1222  return rewriter.create<arith::AddFOp>(loc, v0, v1);
1223  case kAddC:
1224  return rewriter.create<complex::AddOp>(loc, v0, v1);
1225  case kAddI:
1226  return rewriter.create<arith::AddIOp>(loc, v0, v1);
1227  case kSubF:
1228  return rewriter.create<arith::SubFOp>(loc, v0, v1);
1229  case kSubC:
1230  return rewriter.create<complex::SubOp>(loc, v0, v1);
1231  case kSubI:
1232  return rewriter.create<arith::SubIOp>(loc, v0, v1);
1233  case kAndI:
1234  return rewriter.create<arith::AndIOp>(loc, v0, v1);
1235  case kOrI:
1236  return rewriter.create<arith::OrIOp>(loc, v0, v1);
1237  case kXorI:
1238  return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1239  case kShrS:
1240  return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1241  case kShrU:
1242  return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1243  case kShlI:
1244  return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1245  case kBinaryBranch: // semi-ring ops with custom logic.
1246  return insertYieldOp(rewriter, loc,
1247  *tensorExps[e].op->getBlock()->getParent(), {v0});
1248  case kUnary:
1249  return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
1250  case kSelect:
1251  return insertYieldOp(rewriter, loc,
1252  cast<SelectOp>(tensorExps[e].op).getRegion(), {v0});
1253  case kBinary:
1254  return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
1255  case kReduce: {
1256  ReduceOp redOp = cast<ReduceOp>(tensorExps[e].op);
1257  return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1258  }
1259  }
1260  llvm_unreachable("unexpected expression kind in build");
1261 }
1262 
1263 } // namespace sparse_tensor
1264 } // namespace mlir