Merger.cpp
1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Math/IR/Math.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 
18 namespace mlir {
19 namespace sparse_tensor {
20 
21 //===----------------------------------------------------------------------===//
22 // Constructors.
23 //===----------------------------------------------------------------------===//
24 
25 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
26  : kind(k), val(v), op(o) {
27  switch (kind) {
28  // Leaf.
29  case kTensor:
30  assert(x != -1u && y == -1u && !v && !o);
31  tensor = x;
32  break;
33  case kInvariant:
34  assert(x == -1u && y == -1u && v && !o);
35  break;
36  case kIndex:
37  assert(x != -1u && y == -1u && !v && !o);
38  index = x;
39  break;
40  // Unary operations.
41  case kAbsF:
42  case kAbsC:
43  case kCeilF:
44  case kFloorF:
45  case kSqrtF:
46  case kSqrtC:
47  case kExpm1F:
48  case kExpm1C:
49  case kLog1pF:
50  case kLog1pC:
51  case kSinF:
52  case kSinC:
53  case kTanhF:
54  case kTanhC:
55  case kNegF:
56  case kNegC:
57  case kNegI:
58  case kCIm:
59  case kCRe:
60  assert(x != -1u && y == -1u && !v && !o);
61  children.e0 = x;
62  children.e1 = y;
63  break;
64  case kTruncF:
65  case kExtF:
66  case kCastFS:
67  case kCastFU:
68  case kCastSF:
69  case kCastUF:
70  case kCastS:
71  case kCastU:
72  case kCastIdx:
73  case kTruncI:
74  case kBitCast:
75  assert(x != -1u && y == -1u && v && !o);
76  children.e0 = x;
77  children.e1 = y;
78  break;
79  case kBinaryBranch:
80  assert(x != -1u && y == -1u && !v && o);
81  children.e0 = x;
82  children.e1 = y;
83  break;
84  case kUnary:
85  // No assertion on y can be made, as the branching paths involve both
86  // a unary (mapSet) and binary (takeDisj) pathway.
87  assert(x != -1u && !v && o);
88  children.e0 = x;
89  children.e1 = y;
90  break;
91  // Binary operations.
92  case kMulF:
93  case kMulC:
94  case kMulI:
95  case kDivF:
96  case kDivC:
97  case kDivS:
98  case kDivU:
99  case kAddF:
100  case kAddC:
101  case kAddI:
102  case kSubF:
103  case kSubC:
104  case kSubI:
105  case kAndI:
106  case kOrI:
107  case kXorI:
108  case kShrS:
109  case kShrU:
110  case kShlI:
111  assert(x != -1u && y != -1u && !v && !o);
112  children.e0 = x;
113  children.e1 = y;
114  break;
115  case kBinary:
116  assert(x != -1u && y != -1u && !v && o);
117  children.e0 = x;
118  children.e1 = y;
119  break;
120  }
121 }
122 
123 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
124  : bits(n, false), simple(), exp(e) {
125  bits.set(b);
126 }
127 
128 LatPoint::LatPoint(const BitVector &b, unsigned e)
129  : bits(b), simple(), exp(e) {}
130 
131 //===----------------------------------------------------------------------===//
132 // Lattice methods.
133 //===----------------------------------------------------------------------===//
134 
135 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
136  Operation *op) {
137  unsigned e = tensorExps.size();
138  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
139  return e;
140 }
141 
142 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
143  assert(t < numTensors && i < numLoops);
144  unsigned p = latPoints.size();
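 // Note: bit b = numTensors * i + t uniquely encodes the pair (tensor t,
 // loop index i), so each LatPoint's BitVector of length numLoops *
 // numTensors records exactly which tensors participate in which loops.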
145  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
146  return p;
147 }
148 
149 unsigned Merger::addSet() {
150  unsigned s = latSets.size();
151  latSets.emplace_back(SmallVector<unsigned, 16>());
152  return s;
153 }
154 
155 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
156  Operation *op) {
157  unsigned p = latPoints.size();
158  BitVector nb = BitVector(latPoints[p0].bits);
159  nb |= latPoints[p1].bits;
160  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
161  latPoints.push_back(LatPoint(nb, e));
162  return p;
163 }
164 
165 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
166  unsigned s = addSet();
167  for (unsigned p0 : latSets[s0])
168  for (unsigned p1 : latSets[s1])
169  latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
170  return s;
171 }
172 
173 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
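 // Informal example: for x + y this produces lattice points for the overlap
 // of x and y, for x alone, and for y alone, so code generation covers the
 // full union of the two sparse iteration spaces.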
174  unsigned s = takeConj(kind, s0, s1, op);
175  // Followed by all in s0.
176  for (unsigned p : latSets[s0])
177  latSets[s].push_back(p);
178  // Map binary 0-y to unary -y.
179  // TODO: move this if-else logic into buildLattices
180  if (kind == kSubF)
181  s1 = mapSet(kNegF, s1);
182  else if (kind == kSubC)
183  s1 = mapSet(kNegC, s1);
184  else if (kind == kSubI)
185  s1 = mapSet(kNegI, s1);
186  // Followed by all in s1.
187  for (unsigned p : latSets[s1])
188  latSets[s].push_back(p);
189  return s;
190 }
191 
192 unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
193  bool includeLeft, Kind ltrans, Operation *opleft,
194  bool includeRight, Kind rtrans, Operation *opright) {
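 // Start from the overlap (conjunction), then optionally append the
 // left-only and right-only points, each mapped through its transformation
 // kind when a region operation is provided.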
195  unsigned s = takeConj(kind, s0, s1, orig);
196  // Left Region.
197  if (includeLeft) {
198  if (opleft)
199  s0 = mapSet(ltrans, s0, Value(), opleft);
200  for (unsigned p : latSets[s0])
201  latSets[s].push_back(p);
202  }
203  // Right Region.
204  if (includeRight) {
205  if (opright)
206  s1 = mapSet(rtrans, s1, Value(), opright);
207  for (unsigned p : latSets[s1])
208  latSets[s].push_back(p);
209  }
210  return s;
211 }
212 
213 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
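 // The assert below relies on the Kind enum ordering: kAbsF .. kUnary is
 // exactly the block of unary operator kinds that may be mapped over a
 // lattice set.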
214  assert(kAbsF <= kind && kind <= kUnary);
215  unsigned s = addSet();
216  for (unsigned p : latSets[s0]) {
217  unsigned e = addExp(kind, latPoints[p].exp, v, op);
218  latPoints.push_back(LatPoint(latPoints[p].bits, e));
219  latSets[s].push_back(latPoints.size() - 1);
220  }
221  return s;
222 }
223 
224 unsigned Merger::optimizeSet(unsigned s0) {
225  unsigned s = addSet();
226  assert(!latSets[s0].empty());
227  unsigned p0 = latSets[s0][0];
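 // The first (most general) point is always kept; a later point is dropped
 // when it is a plain copy of the output tensor, or when its condition
 // differs from an already kept point only in non-sparse dimensions.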
228  for (unsigned p1 : latSets[s0]) {
229  bool add = true;
230  if (p0 != p1) {
231  // Is this a straightforward copy?
232  unsigned e = latPoints[p1].exp;
233  if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
234  continue;
235  // Conjunction already covered?
236  for (unsigned p2 : latSets[s]) {
237  assert(!latGT(p1, p2)); // Lj => Li would be bad
238  if (onlyDenseDiff(p2, p1)) {
239  add = false;
240  break;
241  }
242  }
243  assert(!add || latGT(p0, p1));
244  }
245  if (add)
246  latSets[s].push_back(p1);
247  }
248  for (unsigned p : latSets[s])
249  latPoints[p].simple = simplifyCond(s, p);
250  return s;
251 }
252 
253 BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
254  // First determine if this lattice point is a *singleton*, i.e.,
255  // the last point in a lattice, no other is less than this one.
256  bool isSingleton = true;
257  for (unsigned p1 : latSets[s0]) {
258  if (p0 != p1 && latGT(p0, p1)) {
259  isSingleton = false;
260  break;
261  }
262  }
263  // Now apply the two basic rules.
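 // (Rule 1: a singleton point that already involves a sparse dimension
 // needs no non-sparse conditions at all. Rule 2: otherwise, at most the
 // first non-sparse condition is kept; the rest are redundant.)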
264  BitVector simple = latPoints[p0].bits;
265  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
266  for (unsigned b = 0, be = simple.size(); b < be; b++) {
267  if (simple[b] && !isDim(b, kSparse)) {
268  if (reset)
269  simple.reset(b);
270  reset = true;
271  }
272  }
273  return simple;
274 }
275 
276 bool Merger::latGT(unsigned i, unsigned j) const {
277  const BitVector &bitsi = latPoints[i].bits;
278  const BitVector &bitsj = latPoints[j].bits;
279  assert(bitsi.size() == bitsj.size());
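 // Li > Lj holds when the loop-index bits of Lj form a proper subset of
 // the bits of Li.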
280  if (bitsi.count() > bitsj.count()) {
281  for (unsigned b = 0, be = bitsj.size(); b < be; b++)
282  if (bitsj[b] && !bitsi[b])
283  return false;
284  return true;
285  }
286  return false;
287 }
288 
289 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
290  BitVector tmp = latPoints[j].bits;
291  tmp ^= latPoints[i].bits;
292  return !hasAnyDimOf(tmp, kSparse);
293 }
294 
295 bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
296  for (unsigned b = 0, be = bits.size(); b < be; b++)
297  if (bits[b] && isDim(b, d))
298  return true;
299  return false;
300 }
301 
302 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
303  switch (tensorExps[e].kind) {
304  // Leaf.
305  case kTensor:
306  return tensorExps[e].tensor == t;
307  case kInvariant:
308  case kIndex:
309  return false;
310  // Unary operations.
311  case kAbsF:
312  case kAbsC:
313  case kCeilF:
314  case kFloorF:
315  case kSqrtF:
316  case kSqrtC:
317  case kExpm1F:
318  case kExpm1C:
319  case kLog1pF:
320  case kLog1pC:
321  case kSinF:
322  case kSinC:
323  case kTanhF:
324  case kTanhC:
325  case kNegF:
326  case kNegC:
327  case kNegI:
328  case kTruncF:
329  case kExtF:
330  case kCastFS:
331  case kCastFU:
332  case kCastSF:
333  case kCastUF:
334  case kCastS:
335  case kCastU:
336  case kCastIdx:
337  case kTruncI:
338  case kCIm:
339  case kCRe:
340  case kBitCast:
341  return isSingleCondition(t, tensorExps[e].children.e0);
342  case kBinaryBranch:
343  case kUnary:
344  return false;
345  // Binary operations.
346  case kDivF: // note: x / c only
347  case kDivC:
348  case kDivS:
349  case kDivU:
350  assert(!maybeZero(tensorExps[e].children.e1));
351  return isSingleCondition(t, tensorExps[e].children.e0);
352  case kShrS: // note: x >> inv only
353  case kShrU:
354  case kShlI:
355  assert(isInvariant(tensorExps[e].children.e1));
356  return isSingleCondition(t, tensorExps[e].children.e0);
357  case kMulF:
358  case kMulC:
359  case kMulI:
360  case kAndI:
361  if (isSingleCondition(t, tensorExps[e].children.e0))
362  return isSingleCondition(t, tensorExps[e].children.e1) ||
363  isInvariant(tensorExps[e].children.e1);
364  if (isSingleCondition(t, tensorExps[e].children.e1))
365  return isInvariant(tensorExps[e].children.e0);
366  return false;
367  case kAddF:
368  case kAddC:
369  case kAddI:
370  return isSingleCondition(t, tensorExps[e].children.e0) &&
371  isSingleCondition(t, tensorExps[e].children.e1);
372  case kSubF:
373  case kSubC:
374  case kSubI:
375  case kOrI:
376  case kXorI:
377  case kBinary:
378  return false;
379  }
380  llvm_unreachable("unexpected kind");
381 }
382 
383 #ifndef NDEBUG
384 
385 //===----------------------------------------------------------------------===//
386 // Print methods (for debugging).
387 //===----------------------------------------------------------------------===//
388 
389 static const char *kindToOpSymbol(Kind kind) {
390  switch (kind) {
391  // Leaf.
392  case kTensor:
393  return "tensor";
394  case kInvariant:
395  return "invariant";
396  case kIndex:
397  return "index";
398  // Unary operations.
399  case kAbsF:
400  case kAbsC:
401  return "abs";
402  case kCeilF:
403  return "ceil";
404  case kFloorF:
405  return "floor";
406  case kSqrtF:
407  case kSqrtC:
408  return "sqrt";
409  case kExpm1F:
410  case kExpm1C:
411  return "expm1";
412  case kLog1pF:
413  case kLog1pC:
414  return "log1p";
415  case kSinF:
416  case kSinC:
417  return "sin";
418  case kTanhF:
419  case kTanhC:
420  return "tanh";
421  case kNegF:
422  case kNegC:
423  case kNegI:
424  return "-";
425  case kTruncF:
426  case kExtF:
427  case kCastFS:
428  case kCastFU:
429  case kCastSF:
430  case kCastUF:
431  case kCastS:
432  case kCastU:
433  case kCastIdx:
434  case kTruncI:
435  case kBitCast:
436  return "cast";
437  case kCIm:
438  return "complex.im";
439  case kCRe:
440  return "complex.re";
441  case kBinaryBranch:
442  return "binary_branch";
443  case kUnary:
444  return "unary";
445  // Binary operations.
446  case kMulF:
447  case kMulC:
448  case kMulI:
449  return "*";
450  case kDivF:
451  case kDivC:
452  case kDivS:
453  case kDivU:
454  return "/";
455  case kAddF:
456  case kAddC:
457  case kAddI:
458  return "+";
459  case kSubF:
460  case kSubC:
461  case kSubI:
462  return "-";
463  case kAndI:
464  return "&";
465  case kOrI:
466  return "|";
467  case kXorI:
468  return "^";
469  case kShrS:
470  return "a>>";
471  case kShrU:
472  return ">>";
473  case kShlI:
474  return "<<";
475  case kBinary:
476  return "binary";
477  }
478  llvm_unreachable("unexpected kind for symbol");
479 }
480 
481 void Merger::dumpExp(unsigned e) const {
482  switch (tensorExps[e].kind) {
483  // Leaf.
484  case kTensor:
485  if (tensorExps[e].tensor == syntheticTensor)
486  llvm::dbgs() << "synthetic_";
487  else if (tensorExps[e].tensor == outTensor)
488  llvm::dbgs() << "output_";
489  llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
490  break;
491  case kInvariant:
492  llvm::dbgs() << "invariant";
493  break;
494  case kIndex:
495  llvm::dbgs() << "index_" << tensorExps[e].index;
496  break;
497  // Unary operations.
498  case kAbsF:
499  case kAbsC:
500  case kCeilF:
501  case kFloorF:
502  case kSqrtF:
503  case kSqrtC:
504  case kExpm1F:
505  case kExpm1C:
506  case kLog1pF:
507  case kLog1pC:
508  case kSinF:
509  case kSinC:
510  case kTanhF:
511  case kTanhC:
512  case kNegF:
513  case kNegC:
514  case kNegI:
515  case kTruncF:
516  case kExtF:
517  case kCastFS:
518  case kCastFU:
519  case kCastSF:
520  case kCastUF:
521  case kCastS:
522  case kCastU:
523  case kCastIdx:
524  case kTruncI:
525  case kCIm:
526  case kCRe:
527  case kBitCast:
528  case kBinaryBranch:
529  case kUnary:
530  llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
531  dumpExp(tensorExps[e].children.e0);
532  break;
533  // Binary operations.
534  case kMulF:
535  case kMulC:
536  case kMulI:
537  case kDivF:
538  case kDivC:
539  case kDivS:
540  case kDivU:
541  case kAddF:
542  case kAddC:
543  case kAddI:
544  case kSubF:
545  case kSubC:
546  case kSubI:
547  case kAndI:
548  case kOrI:
549  case kXorI:
550  case kShrS:
551  case kShrU:
552  case kShlI:
553  case kBinary:
554  llvm::dbgs() << "(";
555  dumpExp(tensorExps[e].children.e0);
556  llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
557  dumpExp(tensorExps[e].children.e1);
558  llvm::dbgs() << ")";
559  }
560 }
561 
562 void Merger::dumpLat(unsigned p) const {
563  llvm::dbgs() << "lat(";
564  dumpBits(latPoints[p].bits);
565  llvm::dbgs() << " :";
566  dumpBits(latPoints[p].simple);
567  llvm::dbgs() << " : ";
568  dumpExp(latPoints[p].exp);
569  llvm::dbgs() << " )\n";
570 }
571 
572 void Merger::dumpSet(unsigned s) const {
573  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
574  for (unsigned p : latSets[s]) {
575  llvm::dbgs() << " ";
576  dumpLat(p);
577  }
578  llvm::dbgs() << "}\n";
579 }
580 
581 void Merger::dumpBits(const BitVector &bits) const {
582  for (unsigned b = 0, be = bits.size(); b < be; b++) {
583  if (bits[b]) {
584  unsigned t = tensor(b);
585  unsigned i = index(b);
586  llvm::dbgs() << " i_" << t << "_" << i << "_";
587  switch (dims[t][i]) {
588  case kSparse:
589  llvm::dbgs() << "S";
590  break;
591  case kDense:
592  llvm::dbgs() << "D";
593  break;
594  case kSingle:
595  llvm::dbgs() << "T";
596  break;
597  case kUndef:
598  llvm::dbgs() << "U";
599  break;
600  }
601  }
602  }
603 }
604 
605 #endif // NDEBUG
606 
607 //===----------------------------------------------------------------------===//
608 // Builder methods.
609 //===----------------------------------------------------------------------===//
610 
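// Informal example: for a kernel such as
//   a(i) = b(i) * c(i)
// with b and c sparse, the lattice for the loop over i holds a single point
// requiring both b and c (conjunction), whereas
//   a(i) = b(i) + c(i)
// yields three points: "b and c", "b only", and "c only" (disjunction).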
611 unsigned Merger::buildLattices(unsigned e, unsigned i) {
612  Kind kind = tensorExps[e].kind;
613  switch (kind) {
614  // Leaf.
615  case kTensor:
616  case kInvariant:
617  case kIndex: {
618  // Either the index is really used in the tensor expression, or it is
619  // set to the undefined index in that dimension. An invariant expression,
620  // a proper index value, and a truly dynamic sparse output tensor are set
621  // to a synthetic tensor with undefined indices only to ensure the
622  // iteration space is not skipped as a result of their contents.
623  unsigned s = addSet();
624  unsigned t = syntheticTensor;
625  if (kind == kTensor) {
626  t = tensorExps[e].tensor;
627  if (hasSparseOut && t == outTensor)
628  t = syntheticTensor;
629  }
630  latSets[s].push_back(addLat(t, i, e));
631  return s;
632  }
633  // Unary operations.
634  case kAbsF:
635  case kAbsC:
636  case kCeilF:
637  case kFloorF:
638  case kSqrtF:
639  case kSqrtC:
640  case kExpm1F:
641  case kExpm1C:
642  case kLog1pF:
643  case kLog1pC:
644  case kSinF:
645  case kSinC:
646  case kTanhF:
647  case kTanhC:
648  case kNegF:
649  case kNegC:
650  case kNegI:
651  case kTruncF:
652  case kExtF:
653  case kCastFS:
654  case kCastFU:
655  case kCastSF:
656  case kCastUF:
657  case kCastS:
658  case kCastU:
659  case kCastIdx:
660  case kTruncI:
661  case kCIm:
662  case kCRe:
663  case kBitCast:
664  // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
665  // lattice set of the operand through the operator into a new set.
666  //
667  // -y|!y | y |
668  // --+---+---+
669  //   | 0 |-y |
670  return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
671  tensorExps[e].val);
672  case kBinaryBranch:
673  // The left or right half of a binary operation which has already
674  // been split into separate operations for each region.
675  return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
676  tensorExps[e].op);
677  case kUnary:
678  // A custom unary operation.
679  //
680  // op y|    !y    |     y      |
681  // ----+----------+------------+
682  //     | absent() | present(y) |
683  {
684  unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
685  UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
686  Region &absentRegion = unop.getAbsentRegion();
687 
688  if (absentRegion.empty()) {
689  // Simple mapping over existing values.
690  return mapSet(kind, child0, Value(), unop);
691  } // Use a disjunction with `unop` on the left and the absent value as an
692  // invariant on the right.
693  Block &absentBlock = absentRegion.front();
694  YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
695  Value absentVal = absentYield.getResult();
696  unsigned rhs = addExp(kInvariant, absentVal);
697  return takeDisj(kind, child0, buildLattices(rhs, i), unop);
698  }
699  // Binary operations.
700  case kMulF:
701  case kMulC:
702  case kMulI:
703  case kAndI:
704  // A multiplicative operation only needs to be performed
705  // for the conjunction of sparse iteration spaces.
706  //
707  // x*y|!y | y |
708  // ---+---+---+
709  // !x | 0 | 0 |
710  //  x | 0 |x*y|
711  //
712  // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
713  return takeConj(kind, // take binary conjunction
714  buildLattices(tensorExps[e].children.e0, i),
715  buildLattices(tensorExps[e].children.e1, i));
716  case kDivF:
717  case kDivC:
718  case kDivS:
719  case kDivU:
720  // A division is tricky, since 0/0, 0/c, c/0 all have
721  // specific outcomes for floating-point and integers.
722  // Thus, we need to traverse the full iteration space.
723  //
724  // x/y|!y | y |
725  // ---+---+---+
726  // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
727  //  x |x/0|x/y| INT: x/0=exception for any x
728  //
729  // TODO: for now we "fixed" this by only accepting x/c cases
730  // during expression building, so that the conjunction
731  // rule applies (viz. x/c = x*(1/c) as far as lattice
732  // construction is concerned).
733  assert(!maybeZero(tensorExps[e].children.e1));
734  return takeConj(kind, // take binary conjunction
735  buildLattices(tensorExps[e].children.e0, i),
736  buildLattices(tensorExps[e].children.e1, i));
737  case kAddF:
738  case kAddC:
739  case kAddI:
740  case kSubF:
741  case kSubC:
742  case kSubI:
743  case kOrI:
744  case kXorI:
745  // An additive operation needs to be performed
746  // for the disjunction of sparse iteration spaces.
747  //
748  // x+y|!y | y | x-y|!y | y |
749  // ---+---+---+ ---+---+---+
750  // !x | 0 | y | !x | 0 |-y |
751  //  x | x |x+y|  x | x |x-y|
752  return takeDisj(kind, // take binary disjunction
753  buildLattices(tensorExps[e].children.e0, i),
754  buildLattices(tensorExps[e].children.e1, i));
755  case kShrS:
756  case kShrU:
757  case kShlI:
758  // A shift operation by an invariant amount (viz. tensor expressions
759  // can only occur at the left-hand-side of the operator) can be handled
760  // with the conjunction rule.
761  assert(isInvariant(tensorExps[e].children.e1));
762  return takeConj(kind, // take binary conjunction
763  buildLattices(tensorExps[e].children.e0, i),
764  buildLattices(tensorExps[e].children.e1, i));
765  case kBinary:
766  // A custom binary operation.
767  //
768  // x op y|   !y    |      y       |
769  // ------+---------+--------------+
770  //   !x  |  empty  |   right(y)   |
771  //    x  | left(x) | overlap(x,y) |
772  {
773  unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
774  unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
775  BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
776  Region &leftRegion = binop.getLeftRegion();
777  Region &rightRegion = binop.getRightRegion();
778  // Left Region.
779  Operation *leftYield = nullptr;
780  if (!leftRegion.empty()) {
781  Block &leftBlock = leftRegion.front();
782  leftYield = leftBlock.getTerminator();
783  }
784  // Right Region.
785  Operation *rightYield = nullptr;
786  if (!rightRegion.empty()) {
787  Block &rightBlock = rightRegion.front();
788  rightYield = rightBlock.getTerminator();
789  }
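 // A side contributes its exclusive lattice points when it either has an
 // explicit region or is declared an identity (values pass through
 // unchanged).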
790  bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
791  bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
792  return takeCombi(kBinary, child0, child1, binop, includeLeft,
793  kBinaryBranch, leftYield, includeRight, kBinaryBranch,
794  rightYield);
795  }
796  }
797  llvm_unreachable("unexpected expression kind");
798 }
799 
800 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
801  // Build the linalg semantics backward from yield.
802  Operation *yield = op.region().front().getTerminator();
803  assert(isa<linalg::YieldOp>(yield));
804  return buildTensorExp(op, yield->getOperand(0));
805 }
806 
807 /// Only returns false if we are certain this is a nonzero.
808 bool Merger::maybeZero(unsigned e) const {
809  if (tensorExps[e].kind == kInvariant) {
810  if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
811  ArrayAttr arrayAttr = c.getValue();
812  return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
813  arrayAttr[1].cast<FloatAttr>().getValue().isZero();
814  }
815  if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
816  return c.value() == 0;
817  if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
818  return c.value().isZero();
819  }
820  return true;
821 }
822 
823 bool Merger::isInvariant(unsigned e) const {
824  return tensorExps[e].kind == kInvariant;
825 }
826 
827 Type Merger::inferType(unsigned e, Value src) {
828  // Obtain the destination type from the cast node.
829  Type dtp = tensorExps[e].val.getType();
830  // Inspect source type. For vector types, apply the same
831  // vectorization to the destination type.
832  if (auto vtp = src.getType().dyn_cast<VectorType>())
833  return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
834  return dtp;
835 }
836 
837 /// Ensures that the sparse compiler can generate code for the expression.
838 static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
839  // Arguments are always admissable.
840  if (auto arg = v.dyn_cast<BlockArgument>())
841  return true;
842  // Accept index anywhere.
843  Operation *def = v.getDefiningOp();
844  if (isa<linalg::IndexOp>(def))
845  return true;
846  // Operation defined outside branch.
847  if (def->getBlock() != block) {
848  return def->getBlock() != op->getBlock(); // invariant?
849  }
850  // Operation defined within branch. Anything is accepted,
851  // as long as all subexpressions are admissable.
852  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
853  if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
854  return false;
855  return true;
856 }
857 
858 /// Ensures that the sparse compiler can generate code for the branch.
859 static bool isAdmissableBranch(Operation *op, Region &region) {
860  if (region.empty())
861  return true;
862  // Build the semi-ring branch semantics backward from yield.
863  Operation *yield = region.front().getTerminator();
864  assert(isa<YieldOp>(yield));
865  return isAdmissableBranchExp(op, &region.front(), yield->getOperand(0));
866 }
867 
868 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
869  if (auto arg = v.dyn_cast<BlockArgument>()) {
870  unsigned argN = arg.getArgNumber();
871  // Any argument of the generic op that is not marked as a scalar
872  // argument is considered a tensor, indexed by the implicit loop
873  // bounds. This includes rank-0 tensor arguments.
874  if (arg.getOwner()->getParentOp() == op) {
875  OpOperand *t = op.getInputAndOutputOperands()[argN];
876  if (!op.isScalar(t))
877  return addExp(kTensor, argN);
878  v = t->get(); // get scalar value
879  }
880  // Any other argument (marked as scalar argument for the generic op
881  // or belonging to an enveloping op) is considered invariant.
882  return addExp(kInvariant, v);
883  }
884  // Something defined outside is invariant.
885  Operation *def = v.getDefiningOp();
886  if (def->getBlock() != &op.region().front())
887  return addExp(kInvariant, v);
888  // Construct index operations.
889  if (def->getNumOperands() == 0) {
890  if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
891  return addExp(kIndex, indexOp.dim());
892  }
893  // Construct unary operations if subexpression can be built.
894  if (def->getNumOperands() == 1) {
895  auto x = buildTensorExp(op, def->getOperand(0));
896  if (x.hasValue()) {
897  unsigned e = x.getValue();
898  if (isa<math::AbsOp>(def))
899  return addExp(kAbsF, e);
900  if (isa<complex::AbsOp>(def))
901  return addExp(kAbsC, e);
902  if (isa<math::CeilOp>(def))
903  return addExp(kCeilF, e);
904  if (isa<math::FloorOp>(def))
905  return addExp(kFloorF, e);
906  if (isa<math::SqrtOp>(def))
907  return addExp(kSqrtF, e);
908  if (isa<complex::SqrtOp>(def))
909  return addExp(kSqrtC, e);
910  if (isa<math::ExpM1Op>(def))
911  return addExp(kExpm1F, e);
912  if (isa<complex::Expm1Op>(def))
913  return addExp(kExpm1C, e);
914  if (isa<math::Log1pOp>(def))
915  return addExp(kLog1pF, e);
916  if (isa<complex::Log1pOp>(def))
917  return addExp(kLog1pC, e);
918  if (isa<math::SinOp>(def))
919  return addExp(kSinF, e);
920  if (isa<complex::SinOp>(def))
921  return addExp(kSinC, e);
922  if (isa<math::TanhOp>(def))
923  return addExp(kTanhF, e);
924  if (isa<complex::TanhOp>(def))
925  return addExp(kTanhC, e);
926  if (isa<arith::NegFOp>(def))
927  return addExp(kNegF, e); // no negi in std
928  if (isa<complex::NegOp>(def))
929  return addExp(kNegC, e);
930  if (isa<arith::TruncFOp>(def))
931  return addExp(kTruncF, e, v);
932  if (isa<arith::ExtFOp>(def))
933  return addExp(kExtF, e, v);
934  if (isa<arith::FPToSIOp>(def))
935  return addExp(kCastFS, e, v);
936  if (isa<arith::FPToUIOp>(def))
937  return addExp(kCastFU, e, v);
938  if (isa<arith::SIToFPOp>(def))
939  return addExp(kCastSF, e, v);
940  if (isa<arith::UIToFPOp>(def))
941  return addExp(kCastUF, e, v);
942  if (isa<arith::ExtSIOp>(def))
943  return addExp(kCastS, e, v);
944  if (isa<arith::ExtUIOp>(def))
945  return addExp(kCastU, e, v);
946  if (isa<arith::IndexCastOp>(def))
947  return addExp(kCastIdx, e, v);
948  if (isa<arith::TruncIOp>(def))
949  return addExp(kTruncI, e, v);
950  if (isa<complex::ImOp>(def))
951  return addExp(kCIm, e);
952  if (isa<complex::ReOp>(def))
953  return addExp(kCRe, e);
954  if (isa<arith::BitcastOp>(def))
955  return addExp(kBitCast, e, v);
956  if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
957  if (isAdmissableBranch(unop, unop.getPresentRegion()) &&
958  isAdmissableBranch(unop, unop.getAbsentRegion()))
959  return addExp(kUnary, e, Value(), def);
960  }
961  }
962  }
963  // Construct binary operations if subexpressions can be built.
964  // See buildLattices() for an explanation of rejecting certain
965  // division and shift operations.
966  if (def->getNumOperands() == 2) {
967  auto x = buildTensorExp(op, def->getOperand(0));
968  auto y = buildTensorExp(op, def->getOperand(1));
969  if (x.hasValue() && y.hasValue()) {
970  unsigned e0 = x.getValue();
971  unsigned e1 = y.getValue();
972  if (isa<arith::MulFOp>(def))
973  return addExp(kMulF, e0, e1);
974  if (isa<complex::MulOp>(def))
975  return addExp(kMulC, e0, e1);
976  if (isa<arith::MulIOp>(def))
977  return addExp(kMulI, e0, e1);
978  if (isa<arith::DivFOp>(def) && !maybeZero(e1))
979  return addExp(kDivF, e0, e1);
980  if (isa<complex::DivOp>(def) && !maybeZero(e1))
981  return addExp(kDivC, e0, e1);
982  if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
983  return addExp(kDivS, e0, e1);
984  if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
985  return addExp(kDivU, e0, e1);
986  if (isa<arith::AddFOp>(def))
987  return addExp(kAddF, e0, e1);
988  if (isa<complex::AddOp>(def))
989  return addExp(kAddC, e0, e1);
990  if (isa<arith::AddIOp>(def))
991  return addExp(kAddI, e0, e1);
992  if (isa<arith::SubFOp>(def))
993  return addExp(kSubF, e0, e1);
994  if (isa<complex::SubOp>(def))
995  return addExp(kSubC, e0, e1);
996  if (isa<arith::SubIOp>(def))
997  return addExp(kSubI, e0, e1);
998  if (isa<arith::AndIOp>(def))
999  return addExp(kAndI, e0, e1);
1000  if (isa<arith::OrIOp>(def))
1001  return addExp(kOrI, e0, e1);
1002  if (isa<arith::XOrIOp>(def))
1003  return addExp(kXorI, e0, e1);
1004  if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1005  return addExp(kShrS, e0, e1);
1006  if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1007  return addExp(kShrU, e0, e1);
1008  if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1009  return addExp(kShlI, e0, e1);
1010  if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1011  if (isAdmissableBranch(binop, binop.getOverlapRegion()) &&
1012  (binop.getLeftIdentity() ||
1013  isAdmissableBranch(binop, binop.getLeftRegion())) &&
1014  (binop.getRightIdentity() ||
1015  isAdmissableBranch(binop, binop.getRightRegion())))
1016  return addExp(kBinary, e0, e1, Value(), def);
1017  }
1018  }
1019  }
1020  // Cannot build.
1021  return None;
1022 }
1023 
1024 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1025  ValueRange vals) {
1026  // Make a clone of overlap region.
1027  Region tmpRegion;
1028  BlockAndValueMapping mapper;
1029  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1030  Block &clonedBlock = tmpRegion.front();
1031  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1032  // Merge cloned block and return yield value.
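 // The temporary constant op only serves as an insertion anchor for
 // mergeBlockBefore and is erased again below.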
1033  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1034  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
1035  Value val = clonedYield.getResult();
1036  rewriter.eraseOp(clonedYield);
1037  rewriter.eraseOp(placeholder);
1038  return val;
1039 }
1040 
1041 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1042  Operation *op, Value v0) {
1043  if (!v0)
1044  // Empty input value must be propagated.
1045  return Value();
1046  UnaryOp unop = cast<UnaryOp>(op);
1047  Region &presentRegion = unop.getPresentRegion();
1048  if (presentRegion.empty())
1049  // Uninitialized Value() will be interpreted as missing data in the
1050  // output.
1051  return Value();
1052  return insertYieldOp(rewriter, loc, presentRegion, {v0});
1053 }
1054 
1055 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1056  Operation *op, Value v0, Value v1) {
1057  if (!v0 || !v1)
1058  // Empty input values must be propagated.
1059  return Value();
1060  BinaryOp binop = cast<BinaryOp>(op);
1061  Region &overlapRegion = binop.getOverlapRegion();
1062  if (overlapRegion.empty())
1063  // Uninitialized Value() will be interpreted as missing data in the
1064  // output.
1065  return Value();
1066  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1067 }
1068 
1069 Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
1070  Value v0, Value v1) {
1071  switch (tensorExps[e].kind) {
1072  // Leaf.
1073  case kTensor:
1074  case kInvariant:
1075  case kIndex:
1076  llvm_unreachable("unexpected non-op");
1077  // Unary operations.
1078  case kAbsF:
1079  return rewriter.create<math::AbsOp>(loc, v0);
1080  case kAbsC: {
1081  auto type = v0.getType().cast<ComplexType>();
1082  auto eltType = type.getElementType().cast<FloatType>();
1083  return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1084  }
1085  case kCeilF:
1086  return rewriter.create<math::CeilOp>(loc, v0);
1087  case kFloorF:
1088  return rewriter.create<math::FloorOp>(loc, v0);
1089  case kSqrtF:
1090  return rewriter.create<math::SqrtOp>(loc, v0);
1091  case kSqrtC:
1092  return rewriter.create<complex::SqrtOp>(loc, v0);
1093  case kExpm1F:
1094  return rewriter.create<math::ExpM1Op>(loc, v0);
1095  case kExpm1C:
1096  return rewriter.create<complex::Expm1Op>(loc, v0);
1097  case kLog1pF:
1098  return rewriter.create<math::Log1pOp>(loc, v0);
1099  case kLog1pC:
1100  return rewriter.create<complex::Log1pOp>(loc, v0);
1101  case kSinF:
1102  return rewriter.create<math::SinOp>(loc, v0);
1103  case kSinC:
1104  return rewriter.create<complex::SinOp>(loc, v0);
1105  case kTanhF:
1106  return rewriter.create<math::TanhOp>(loc, v0);
1107  case kTanhC:
1108  return rewriter.create<complex::TanhOp>(loc, v0);
1109  case kNegF:
1110  return rewriter.create<arith::NegFOp>(loc, v0);
1111  case kNegC:
1112  return rewriter.create<complex::NegOp>(loc, v0);
1113  case kNegI: // no negi in std
1114  return rewriter.create<arith::SubIOp>(
1115  loc,
1116  rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1117  rewriter.getZeroAttr(v0.getType())),
1118  v0);
1119  case kTruncF:
1120  return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1121  case kExtF:
1122  return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1123  case kCastFS:
1124  return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1125  case kCastFU:
1126  return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1127  case kCastSF:
1128  return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1129  case kCastUF:
1130  return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1131  case kCastS:
1132  return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1133  case kCastU:
1134  return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1135  case kCastIdx:
1136  return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1137  case kTruncI:
1138  return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1139  case kCIm: {
1140  auto type = v0.getType().cast<ComplexType>();
1141  auto eltType = type.getElementType().cast<FloatType>();
1142  return rewriter.create<complex::ImOp>(loc, eltType, v0);
1143  }
1144  case kCRe: {
1145  auto type = v0.getType().cast<ComplexType>();
1146  auto eltType = type.getElementType().cast<FloatType>();
1147  return rewriter.create<complex::ReOp>(loc, eltType, v0);
1148  }
1149  case kBitCast:
1150  return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1151  // Binary operations.
1152  case kMulF:
1153  return rewriter.create<arith::MulFOp>(loc, v0, v1);
1154  case kMulC:
1155  return rewriter.create<complex::MulOp>(loc, v0, v1);
1156  case kMulI:
1157  return rewriter.create<arith::MulIOp>(loc, v0, v1);
1158  case kDivF:
1159  return rewriter.create<arith::DivFOp>(loc, v0, v1);
1160  case kDivC:
1161  return rewriter.create<complex::DivOp>(loc, v0, v1);
1162  case kDivS:
1163  return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1164  case kDivU:
1165  return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1166  case kAddF:
1167  return rewriter.create<arith::AddFOp>(loc, v0, v1);
1168  case kAddC:
1169  return rewriter.create<complex::AddOp>(loc, v0, v1);
1170  case kAddI:
1171  return rewriter.create<arith::AddIOp>(loc, v0, v1);
1172  case kSubF:
1173  return rewriter.create<arith::SubFOp>(loc, v0, v1);
1174  case kSubC:
1175  return rewriter.create<complex::SubOp>(loc, v0, v1);
1176  case kSubI:
1177  return rewriter.create<arith::SubIOp>(loc, v0, v1);
1178  case kAndI:
1179  return rewriter.create<arith::AndIOp>(loc, v0, v1);
1180  case kOrI:
1181  return rewriter.create<arith::OrIOp>(loc, v0, v1);
1182  case kXorI:
1183  return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1184  case kShrS:
1185  return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1186  case kShrU:
1187  return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1188  case kShlI:
1189  return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1190  case kBinaryBranch: // semi-ring ops with custom logic.
1191  return insertYieldOp(rewriter, loc,
1192  *tensorExps[e].op->getBlock()->getParent(), {v0});
1193  case kUnary:
1194  return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
1195  case kBinary:
1196  return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
1197  }
1198  llvm_unreachable("unexpected expression kind in build");
1199 }
1200 
1201 } // namespace sparse_tensor
1202 } // namespace mlir
Definition: Merger.h:109