MLIR  14.0.0git
Merger.cpp
Go to the documentation of this file.
1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 #include "mlir/IR/Operation.h"
13 #include "llvm/Support/Debug.h"
14 
15 namespace mlir {
16 namespace sparse_tensor {
17 
18 //===----------------------------------------------------------------------===//
19 // Constructors.
20 //===----------------------------------------------------------------------===//
21 
22 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
23  : kind(k), val(v) {
24  switch (kind) {
25  case kTensor:
26  assert(x != -1u && y == -1u && !v);
27  tensor = x;
28  break;
29  case kInvariant:
30  assert(x == -1u && y == -1u && v);
31  break;
32  case kAbsF:
33  case kCeilF:
34  case kFloorF:
35  case kNegF:
36  case kNegI:
37  assert(x != -1u && y == -1u && !v);
38  children.e0 = x;
39  children.e1 = y;
40  break;
41  case kTruncF:
42  case kExtF:
43  case kCastFS:
44  case kCastFU:
45  case kCastSF:
46  case kCastUF:
47  case kCastS:
48  case kCastU:
49  case kTruncI:
50  case kBitCast:
51  assert(x != -1u && y == -1u && v);
52  children.e0 = x;
53  children.e1 = y;
54  break;
55  default:
56  assert(x != -1u && y != -1u && !v);
57  children.e0 = x;
58  children.e1 = y;
59  break;
60  }
61 }
62 
63 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
64  : bits(n, false), simple(), exp(e) {
65  bits.set(b);
66 }
67 
68 LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
69  : bits(b), simple(), exp(e) {}
70 
71 //===----------------------------------------------------------------------===//
72 // Lattice methods.
73 //===----------------------------------------------------------------------===//
74 
75 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
76  unsigned e = tensorExps.size();
77  tensorExps.push_back(TensorExp(k, e0, e1, v));
78  return e;
79 }
80 
81 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
82  assert(t < numTensors && i < numLoops);
83  unsigned p = latPoints.size();
84  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
85  return p;
86 }
87 
88 unsigned Merger::addSet() {
89  unsigned s = latSets.size();
90  latSets.emplace_back(SmallVector<unsigned, 16>());
91  return s;
92 }
93 
94 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
95  unsigned p = latPoints.size();
96  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
97  nb |= latPoints[p1].bits;
98  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
99  latPoints.push_back(LatPoint(nb, e));
100  return p;
101 }
102 
103 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
104  unsigned s = addSet();
105  for (unsigned p0 : latSets[s0])
106  for (unsigned p1 : latSets[s1])
107  latSets[s].push_back(conjLatPoint(kind, p0, p1));
108  return s;
109 }
110 
111 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
112  unsigned s = takeConj(kind, s0, s1);
113  // Followed by all in s0.
114  for (unsigned p : latSets[s0])
115  latSets[s].push_back(p);
116  // Map binary 0-y to unary -y.
117  if (kind == kSubF)
118  s1 = mapSet(kNegF, s1);
119  else if (kind == kSubI)
120  s1 = mapSet(kNegI, s1);
121  // Followed by all in s1.
122  for (unsigned p : latSets[s1])
123  latSets[s].push_back(p);
124  return s;
125 }
126 
127 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
128  assert(kAbsF <= kind && kind <= kBitCast);
129  unsigned s = addSet();
130  for (unsigned p : latSets[s0]) {
131  unsigned e = addExp(kind, latPoints[p].exp, v);
132  latPoints.push_back(LatPoint(latPoints[p].bits, e));
133  latSets[s].push_back(latPoints.size() - 1);
134  }
135  return s;
136 }
137 
138 unsigned Merger::optimizeSet(unsigned s0) {
139  unsigned s = addSet();
140  assert(!latSets[s0].empty());
141  unsigned p0 = latSets[s0][0];
142  for (unsigned p1 : latSets[s0]) {
143  bool add = true;
144  if (p0 != p1) {
145  // Is this a straightforward copy?
146  unsigned e = latPoints[p1].exp;
147  if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
148  continue;
149  // Conjunction already covered?
150  for (unsigned p2 : latSets[s]) {
151  assert(!latGT(p1, p2)); // Lj => Li would be bad
152  if (onlyDenseDiff(p2, p1)) {
153  add = false;
154  break;
155  }
156  }
157  assert(!add || latGT(p0, p1));
158  }
159  if (add)
160  latSets[s].push_back(p1);
161  }
162  for (unsigned p : latSets[s])
163  latPoints[p].simple = simplifyCond(s, p);
164  return s;
165 }
166 
167 llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
168  // First determine if this lattice point is a *singleton*, i.e.,
169  // the last point in a lattice, no other is less than this one.
170  bool isSingleton = true;
171  for (unsigned p1 : latSets[s0]) {
172  if (p0 != p1 && latGT(p0, p1)) {
173  isSingleton = false;
174  break;
175  }
176  }
177  // Now apply the two basic rules.
178  llvm::BitVector simple = latPoints[p0].bits;
179  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
180  for (unsigned b = 0, be = simple.size(); b < be; b++) {
181  if (simple[b] && !isDim(b, kSparse)) {
182  if (reset)
183  simple.reset(b);
184  reset = true;
185  }
186  }
187  return simple;
188 }
189 
190 bool Merger::latGT(unsigned i, unsigned j) const {
191  const llvm::BitVector &bitsi = latPoints[i].bits;
192  const llvm::BitVector &bitsj = latPoints[j].bits;
193  assert(bitsi.size() == bitsj.size());
194  if (bitsi.count() > bitsj.count()) {
195  for (unsigned b = 0, be = bitsj.size(); b < be; b++)
196  if (bitsj[b] && !bitsi[b])
197  return false;
198  return true;
199  }
200  return false;
201 }
202 
203 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
204  llvm::BitVector tmp = latPoints[j].bits;
205  tmp ^= latPoints[i].bits;
206  return !hasAnyDimOf(tmp, kSparse);
207 }
208 
209 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
210  for (unsigned b = 0, be = bits.size(); b < be; b++)
211  if (bits[b] && isDim(b, d))
212  return true;
213  return false;
214 }
215 
216 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
217  switch (tensorExps[e].kind) {
218  case kTensor:
219  return tensorExps[e].tensor == t;
220  case kAbsF:
221  case kCeilF:
222  case kFloorF:
223  case kNegF:
224  case kNegI:
225  case kTruncF:
226  case kExtF:
227  case kCastFS:
228  case kCastFU:
229  case kCastSF:
230  case kCastUF:
231  case kCastS:
232  case kCastU:
233  case kTruncI:
234  case kBitCast:
235  return isSingleCondition(t, tensorExps[e].children.e0);
236  case kDivF: // note: x / c only
237  case kDivS:
238  case kDivU:
239  assert(!maybeZero(tensorExps[e].children.e1));
240  return isSingleCondition(t, tensorExps[e].children.e0);
241  case kShrS: // note: x >> inv only
242  case kShrU:
243  case kShlI:
244  assert(isInvariant(tensorExps[e].children.e1));
245  return isSingleCondition(t, tensorExps[e].children.e0);
246  case kMulF:
247  case kMulI:
248  case kAndI:
249  if (isSingleCondition(t, tensorExps[e].children.e0))
250  return isSingleCondition(t, tensorExps[e].children.e1) ||
251  isInvariant(tensorExps[e].children.e1);
252  if (isSingleCondition(t, tensorExps[e].children.e1))
253  return isInvariant(tensorExps[e].children.e0);
254  return false;
255  case kAddF:
256  case kAddI:
257  return isSingleCondition(t, tensorExps[e].children.e0) &&
258  isSingleCondition(t, tensorExps[e].children.e1);
259  default:
260  return false;
261  }
262 }
263 
264 #ifndef NDEBUG
265 
266 //===----------------------------------------------------------------------===//
267 // Print methods (for debugging).
268 //===----------------------------------------------------------------------===//
269 
270 static const char *kindToOpSymbol(Kind kind) {
271  switch (kind) {
272  case kTensor:
273  return "tensor";
274  case kInvariant:
275  return "invariant";
276  case kAbsF:
277  return "abs";
278  case kCeilF:
279  return "ceil";
280  case kFloorF:
281  return "floor";
282  case kNegF:
283  return "-";
284  case kNegI:
285  return "-";
286  case kTruncF:
287  case kExtF:
288  case kCastFS:
289  case kCastFU:
290  case kCastSF:
291  case kCastUF:
292  case kCastS:
293  case kCastU:
294  case kTruncI:
295  case kBitCast:
296  return "cast";
297  case kMulF:
298  return "*";
299  case kMulI:
300  return "*";
301  case kDivF:
302  return "/";
303  case kDivS:
304  return "/";
305  case kDivU:
306  return "/";
307  case kAddF:
308  return "+";
309  case kAddI:
310  return "+";
311  case kSubF:
312  return "-";
313  case kSubI:
314  return "-";
315  case kAndI:
316  return "&";
317  case kOrI:
318  return "|";
319  case kXorI:
320  return "^";
321  case kShrS:
322  return "a>>";
323  case kShrU:
324  return ">>";
325  case kShlI:
326  return "<<";
327  }
328  llvm_unreachable("unexpected kind for symbol");
329 }
330 
331 void Merger::dumpExp(unsigned e) const {
332  switch (tensorExps[e].kind) {
333  case kTensor:
334  if (tensorExps[e].tensor == syntheticTensor)
335  llvm::dbgs() << "synthetic_";
336  else if (tensorExps[e].tensor == outTensor)
337  llvm::dbgs() << "output_";
338  llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
339  break;
340  case kInvariant:
341  llvm::dbgs() << "invariant";
342  break;
343  case kAbsF:
344  case kCeilF:
345  case kFloorF:
346  case kNegF:
347  case kNegI:
348  case kTruncF:
349  case kExtF:
350  case kCastFS:
351  case kCastFU:
352  case kCastSF:
353  case kCastUF:
354  case kCastS:
355  case kCastU:
356  case kTruncI:
357  case kBitCast:
358  llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
359  dumpExp(tensorExps[e].children.e0);
360  break;
361  default:
362  llvm::dbgs() << "(";
363  dumpExp(tensorExps[e].children.e0);
364  llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
365  dumpExp(tensorExps[e].children.e1);
366  llvm::dbgs() << ")";
367  }
368 }
369 
370 void Merger::dumpLat(unsigned p) const {
371  llvm::dbgs() << "lat(";
372  dumpBits(latPoints[p].bits);
373  llvm::dbgs() << " :";
374  dumpBits(latPoints[p].simple);
375  llvm::dbgs() << " : ";
376  dumpExp(latPoints[p].exp);
377  llvm::dbgs() << " )\n";
378 }
379 
380 void Merger::dumpSet(unsigned s) const {
381  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
382  for (unsigned p : latSets[s]) {
383  llvm::dbgs() << " ";
384  dumpLat(p);
385  }
386  llvm::dbgs() << "}\n";
387 }
388 
389 void Merger::dumpBits(const llvm::BitVector &bits) const {
390  for (unsigned b = 0, be = bits.size(); b < be; b++) {
391  if (bits[b]) {
392  unsigned t = tensor(b);
393  unsigned i = index(b);
394  llvm::dbgs() << " i_" << t << "_" << i << "_";
395  switch (dims[t][i]) {
396  case kSparse:
397  llvm::dbgs() << "S";
398  break;
399  case kDense:
400  llvm::dbgs() << "D";
401  break;
402  case kSingle:
403  llvm::dbgs() << "T";
404  break;
405  case kUndef:
406  llvm::dbgs() << "U";
407  break;
408  }
409  }
410  }
411 }
412 
413 #endif // NDEBUG
414 
415 //===----------------------------------------------------------------------===//
416 // Builder methods.
417 //===----------------------------------------------------------------------===//
418 
419 unsigned Merger::buildLattices(unsigned e, unsigned i) {
420  Kind kind = tensorExps[e].kind;
421  switch (kind) {
422  case kTensor:
423  case kInvariant: {
424  // Either the index is really used in the tensor expression, or it is
425  // set to the undefined index in that dimension. An invariant expression
426  // and a truly dynamic sparse output tensor are set to a synthetic tensor
427  // with undefined indices only to ensure the iteration space is not
428  // skipped as a result of their contents.
429  unsigned s = addSet();
430  unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
431  if (hasSparseOut && t == outTensor)
432  t = syntheticTensor;
433  latSets[s].push_back(addLat(t, i, e));
434  return s;
435  }
436  case kAbsF:
437  case kCeilF:
438  case kFloorF:
439  case kNegF:
440  case kNegI:
441  case kTruncF:
442  case kExtF:
443  case kCastFS:
444  case kCastFU:
445  case kCastSF:
446  case kCastUF:
447  case kCastS:
448  case kCastU:
449  case kTruncI:
450  case kBitCast:
451  // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
452  // lattice set of the operand through the operator into a new set.
453  //
454  // -y|!y | y |
455  // --+---+---+
456  // | 0 |-y |
457  return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
458  tensorExps[e].val);
459  case kMulF:
460  case kMulI:
461  case kAndI:
462  // A multiplicative operation only needs to be performed
463  // for the conjunction of sparse iteration spaces.
464  //
465  // x*y|!y | y |
466  // ---+---+---+
467  // !x | 0 | 0 |
468  // x | 0 |x*y|
469  return takeConj(kind, // take binary conjunction
470  buildLattices(tensorExps[e].children.e0, i),
471  buildLattices(tensorExps[e].children.e1, i));
472  case kDivF:
473  case kDivS:
474  case kDivU:
475  // A division is tricky, since 0/0, 0/c, c/0 all have
476  // specific outcomes for floating-point and integers.
477  // Thus, we need to traverse the full iteration space.
478  //
479  // x/y|!y | y |
480  // ---+---+---+
481  // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
482  // x |x/0|x/y| INT: x/0=exception for any x
483  //
484  // TODO: for now we "fixed" this by only accepting x/c cases
485  // during expression building, so that the conjunction
486  // rules applies (viz. x/c = x*(1/c) as far as lattice
487  // construction is concerned).
488  assert(!maybeZero(tensorExps[e].children.e1));
489  return takeConj(kind, // take binary conjunction
490  buildLattices(tensorExps[e].children.e0, i),
491  buildLattices(tensorExps[e].children.e1, i));
492  case kAddF:
493  case kAddI:
494  case kSubF:
495  case kSubI:
496  case kOrI:
497  case kXorI:
498  // An additive operation needs to be performed
499  // for the disjunction of sparse iteration spaces.
500  //
501  // x+y|!y | y | x-y|!y | y |
502  // ---+---+---+ ---+---+---+
503  // !x | 0 | y | !x | 0 |-y |
504  // x | x |x+y| x | x |x-y|
505  return takeDisj(kind, // take binary disjunction
506  buildLattices(tensorExps[e].children.e0, i),
507  buildLattices(tensorExps[e].children.e1, i));
508  case kShrS:
509  case kShrU:
510  case kShlI:
511  // A shift operation by an invariant amount (viz. tensor expressions
512  // can only occur at the left-hand-side of the operator) can be handled
513  // with the conjuction rule.
514  assert(isInvariant(tensorExps[e].children.e1));
515  return takeConj(kind, // take binary conjunction
516  buildLattices(tensorExps[e].children.e0, i),
517  buildLattices(tensorExps[e].children.e1, i));
518  }
519  llvm_unreachable("unexpected expression kind");
520 }
521 
523  Operation *yield = op.region().front().getTerminator();
524  return buildTensorExp(op, yield->getOperand(0));
525 }
526 
527 /// Only returns false if we are certain this is a nonzero.
528 bool Merger::maybeZero(unsigned e) const {
529  if (tensorExps[e].kind == kInvariant) {
530  if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
531  return c.value() == 0;
532  if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
533  return c.value().isZero();
534  }
535  return true;
536 }
537 
538 bool Merger::isInvariant(unsigned e) const {
539  return tensorExps[e].kind == kInvariant;
540 }
541 
542 Type Merger::inferType(unsigned e, Value src) {
543  // Obtain the destination type from the cast node.
544  Type dtp = tensorExps[e].val.getType();
545  // Inspect source type. For vector types, apply the same
546  // vectorization to the destination type.
547  if (auto vtp = src.getType().dyn_cast<VectorType>())
548  return VectorType::get(vtp.getNumElements(), dtp);
549  return dtp;
550 }
551 
552 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
553  if (auto arg = v.dyn_cast<BlockArgument>()) {
554  unsigned argN = arg.getArgNumber();
555  // Any argument of the generic op that is not marked as a scalar
556  // argument is considered a tensor, indexed by the implicit loop
557  // bounds. This includes rank-0 tensor arguments.
558  if (arg.getOwner()->getParentOp() == op) {
559  OpOperand *t = op.getInputAndOutputOperands()[argN];
560  if (!op.isScalar(t))
561  return addExp(kTensor, argN);
562  v = t->get(); // get scalar value
563  }
564  // Any other argument (marked as scalar argument for the generic op
565  // or belonging to an enveloping op) is considered invariant.
566  return addExp(kInvariant, v);
567  }
568  // Something defined outside is invariant.
569  Operation *def = v.getDefiningOp();
570  if (def->getBlock() != &op.region().front())
571  return addExp(kInvariant, v);
572  // Construct unary operations if subexpression can be built.
573  if (def->getNumOperands() == 1) {
574  auto x = buildTensorExp(op, def->getOperand(0));
575  if (x.hasValue()) {
576  unsigned e = x.getValue();
577  if (isa<math::AbsOp>(def))
578  return addExp(kAbsF, e);
579  if (isa<math::CeilOp>(def))
580  return addExp(kCeilF, e);
581  if (isa<math::FloorOp>(def))
582  return addExp(kFloorF, e);
583  if (isa<arith::NegFOp>(def))
584  return addExp(kNegF, e); // no negi in std
585  if (isa<arith::TruncFOp>(def))
586  return addExp(kTruncF, e, v);
587  if (isa<arith::ExtFOp>(def))
588  return addExp(kExtF, e, v);
589  if (isa<arith::FPToSIOp>(def))
590  return addExp(kCastFS, e, v);
591  if (isa<arith::FPToUIOp>(def))
592  return addExp(kCastFU, e, v);
593  if (isa<arith::SIToFPOp>(def))
594  return addExp(kCastSF, e, v);
595  if (isa<arith::UIToFPOp>(def))
596  return addExp(kCastUF, e, v);
597  if (isa<arith::ExtSIOp>(def))
598  return addExp(kCastS, e, v);
599  if (isa<arith::ExtUIOp>(def))
600  return addExp(kCastU, e, v);
601  if (isa<arith::TruncIOp>(def))
602  return addExp(kTruncI, e, v);
603  if (isa<arith::BitcastOp>(def))
604  return addExp(kBitCast, e, v);
605  }
606  }
607  // Construct binary operations if subexpressions can be built.
608  // See buildLattices() for an explanation of rejecting certain
609  // division and shift operations
610  if (def->getNumOperands() == 2) {
611  auto x = buildTensorExp(op, def->getOperand(0));
612  auto y = buildTensorExp(op, def->getOperand(1));
613  if (x.hasValue() && y.hasValue()) {
614  unsigned e0 = x.getValue();
615  unsigned e1 = y.getValue();
616  if (isa<arith::MulFOp>(def))
617  return addExp(kMulF, e0, e1);
618  if (isa<arith::MulIOp>(def))
619  return addExp(kMulI, e0, e1);
620  if (isa<arith::DivFOp>(def) && !maybeZero(e1))
621  return addExp(kDivF, e0, e1);
622  if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
623  return addExp(kDivS, e0, e1);
624  if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
625  return addExp(kDivU, e0, e1);
626  if (isa<arith::AddFOp>(def))
627  return addExp(kAddF, e0, e1);
628  if (isa<arith::AddIOp>(def))
629  return addExp(kAddI, e0, e1);
630  if (isa<arith::SubFOp>(def))
631  return addExp(kSubF, e0, e1);
632  if (isa<arith::SubIOp>(def))
633  return addExp(kSubI, e0, e1);
634  if (isa<arith::AndIOp>(def))
635  return addExp(kAndI, e0, e1);
636  if (isa<arith::OrIOp>(def))
637  return addExp(kOrI, e0, e1);
638  if (isa<arith::XOrIOp>(def))
639  return addExp(kXorI, e0, e1);
640  if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
641  return addExp(kShrS, e0, e1);
642  if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
643  return addExp(kShrU, e0, e1);
644  if (isa<arith::ShLIOp>(def) && isInvariant(e1))
645  return addExp(kShlI, e0, e1);
646  }
647  }
648  // Cannot build.
649  return None;
650 }
651 
652 Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
653  Value v0, Value v1) {
654  switch (tensorExps[e].kind) {
655  case kTensor:
656  case kInvariant:
657  llvm_unreachable("unexpected non-op");
658  // Unary ops.
659  case kAbsF:
660  return rewriter.create<math::AbsOp>(loc, v0);
661  case kCeilF:
662  return rewriter.create<math::CeilOp>(loc, v0);
663  case kFloorF:
664  return rewriter.create<math::FloorOp>(loc, v0);
665  case kNegF:
666  return rewriter.create<arith::NegFOp>(loc, v0);
667  case kNegI: // no negi in std
668  return rewriter.create<arith::SubIOp>(
669  loc,
670  rewriter.create<arith::ConstantOp>(loc, v0.getType(),
671  rewriter.getZeroAttr(v0.getType())),
672  v0);
673  case kTruncF:
674  return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
675  case kExtF:
676  return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
677  case kCastFS:
678  return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
679  case kCastFU:
680  return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
681  case kCastSF:
682  return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
683  case kCastUF:
684  return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
685  case kCastS:
686  return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
687  case kCastU:
688  return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
689  case kTruncI:
690  return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
691  case kBitCast:
692  return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
693  // Binary ops.
694  case kMulF:
695  return rewriter.create<arith::MulFOp>(loc, v0, v1);
696  case kMulI:
697  return rewriter.create<arith::MulIOp>(loc, v0, v1);
698  case kDivF:
699  return rewriter.create<arith::DivFOp>(loc, v0, v1);
700  case kDivS:
701  return rewriter.create<arith::DivSIOp>(loc, v0, v1);
702  case kDivU:
703  return rewriter.create<arith::DivUIOp>(loc, v0, v1);
704  case kAddF:
705  return rewriter.create<arith::AddFOp>(loc, v0, v1);
706  case kAddI:
707  return rewriter.create<arith::AddIOp>(loc, v0, v1);
708  case kSubF:
709  return rewriter.create<arith::SubFOp>(loc, v0, v1);
710  case kSubI:
711  return rewriter.create<arith::SubIOp>(loc, v0, v1);
712  case kAndI:
713  return rewriter.create<arith::AndIOp>(loc, v0, v1);
714  case kOrI:
715  return rewriter.create<arith::OrIOp>(loc, v0, v1);
716  case kXorI:
717  return rewriter.create<arith::XOrIOp>(loc, v0, v1);
718  case kShrS:
719  return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
720  case kShrU:
721  return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
722  case kShlI:
723  return rewriter.create<arith::ShLIOp>(loc, v0, v1);
724  }
725  llvm_unreachable("unexpected expression kind in build");
726 }
727 
728 } // namespace sparse_tensor
729 } // namespace mlir
Kind
Tensor expression kind.
Definition: Merger.h:27
Include the generated interface declarations.
llvm::BitVector bits
Conjunction of tensor loop indices as bitvector.
Definition: Merger.h:101
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const
Returns true if any set bit corresponds to queried dim.
Definition: Merger.cpp:209
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:264
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:41
Value getOperand(unsigned idx)
Definition: Operation.h:219
unsigned getNumOperands()
Definition: Operation.h:215
unsigned mapSet(Kind kind, unsigned s0, Value v=Value())
Maps the unary operator over the lattice set of the operand, i.e.
Definition: Merger.cpp:127
bool onlyDenseDiff(unsigned i, unsigned j)
Returns true if Li and Lj only differ in dense.
Definition: Merger.cpp:203
unsigned addLat(unsigned t, unsigned i, unsigned e)
Adds an iteration lattice point. Returns its index.
Definition: Merger.cpp:81
Operation & front()
Definition: Block.h:144
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
bool latGT(unsigned i, unsigned j) const
Returns true if Li > Lj.
Definition: Merger.cpp:190
unsigned optimizeSet(unsigned s0)
Optimizes the iteration lattice points in the given set.
Definition: Merger.cpp:138
void dumpBits(const llvm::BitVector &bits) const
Definition: Merger.cpp:389
unsigned exp
Index of the tensor expression.
Definition: Merger.h:109
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void dumpSet(unsigned s) const
Definition: Merger.cpp:380
unsigned addExp(Kind k, unsigned e0, unsigned e1=-1u, Value v=Value())
Adds a tensor expression. Returns its index.
Definition: Merger.cpp:75
unsigned takeConj(Kind kind, unsigned s0, unsigned s1)
Conjunctive merge of two lattice sets L0 and L1 is conjunction of cartesian product.
Definition: Merger.cpp:103
U dyn_cast() const
Definition: Types.h:244
llvm::BitVector simple
Simplified conjunction of tensor loop indices as bitvector.
Definition: Merger.h:106
void dumpExp(unsigned e) const
Print methods (for debugging).
Definition: Merger.cpp:331
U dyn_cast() const
Definition: Value.h:99
Tensor expression. Represents a MLIR expression in tensor index notation.
Definition: Merger.h:72
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1)
Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
Definition: Merger.cpp:111
Kind kind
Tensor expression kind.
Definition: Merger.h:76
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:133
Value buildExp(PatternRewriter &rewriter, Location loc, unsigned e, Value v0, Value v1)
Rebuilds SSA format from a tensor expression.
Definition: Merger.cpp:652
This class represents an argument of a Block.
Definition: Value.h:298
Eliminates identifier at the specified position using Fourier-Motzkin variable elimination.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Specialization of arith.constant op that returns a floating point value.
Definition: Arithmetic.h:62
Dim
Dimension level type for a tensor (undef means index does not appear).
Definition: Merger.h:24
Type getType() const
Return the type of this value.
Definition: Value.h:117
llvm::BitVector simplifyCond(unsigned s0, unsigned p0)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
Definition: Merger.cpp:167
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1)
Computes a single conjunction of two lattice points by taking the "union" of loop indices (effectivel...
Definition: Merger.cpp:94
static const char * kindToOpSymbol(Kind kind)
Definition: Merger.cpp:270
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class represents an operand of an operation.
Definition: Value.h:249
TensorExp(Kind k, unsigned x, unsigned y, Value v)
Definition: Merger.cpp:22
void dumpLat(unsigned p) const
Definition: Merger.cpp:370
unsigned addSet()
Adds a new, initially empty, set. Returns its index.
Definition: Merger.cpp:88
unsigned buildLattices(unsigned e, unsigned i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
Definition: Merger.cpp:419
bool isSingleCondition(unsigned t, unsigned e) const
Returns true if given tensor iterates only in the given tensor expression.
Definition: Merger.cpp:216
LatPoint(unsigned n, unsigned e, unsigned b)
Definition: Merger.cpp:63
Optional< unsigned > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
Definition: Merger.cpp:522
Children children
Tensor operations hold the indices of their children.
Definition: Merger.h:83