Merger.cpp
1//===- Merger.cpp - Implementation of iteration lattices ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Complex/IR/Complex.h"
13#include "mlir/Dialect/Math/IR/Math.h"
14
15#include "mlir/IR/Operation.h"
16#include "llvm/Support/Debug.h"
17#include <optional>
18
19namespace mlir {
20namespace sparse_tensor {
21
22enum class ExpArity {
23 kNullary,
24 kUnary,
25 kBinary,
26};
27
28static ExpArity getExpArity(TensorExp::Kind k) {
29 switch (k) {
30 // Leaf.
31 case TensorExp::Kind::kTensor:
32 case TensorExp::Kind::kInvariant:
33 case TensorExp::Kind::kSynZero:
34 case TensorExp::Kind::kLoopVar:
35 return ExpArity::kNullary;
71 return ExpArity::kUnary;
72 // Binary operations.
96 case TensorExp::Kind::kDenseOp: // kDenseOp can have *at most* two operands
97 return ExpArity::kBinary;
98 }
99 llvm_unreachable("unexpected kind");
100}
101
102//===----------------------------------------------------------------------===//
103// Constructors.
104//===----------------------------------------------------------------------===//
105
106TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v,
107 Operation *o, Attribute a)
108 : kind(k), val(v), op(o), attr(a) {
109 switch (kind) {
110 // Leaf.
111 case TensorExp::Kind::kTensor:
112 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
113 tensor = x;
114 return;
115 case TensorExp::Kind::kSynZero:
116 assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
117 return;
118 case TensorExp::Kind::kInvariant:
119 assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
120 return;
121 case TensorExp::Kind::kLoopVar:
122 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
123 loop = x;
124 return;
125 // Unary operations.
147 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
148 children.e0 = x;
149 children.e1 = y;
150 return;
162 assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
163 children.e0 = x;
164 children.e1 = y;
165 return;
166 case TensorExp::Kind::kBinaryBranch:
167 case TensorExp::Kind::kSelect:
168 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
169 children.e0 = x;
170 children.e1 = y;
171 return;
172 case TensorExp::Kind::kUnary:
173 // No assertion on y can be made, as the branching paths involve both
174 // a unary (`mapSet`) and a binary (`disjSet`) pathway.
175 assert(x != detail::kInvalidId && !v && o);
176 children.e0 = x;
177 children.e1 = y;
178 return;
179 // Binary operations.
199 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
200 children.e0 = x;
201 children.e1 = y;
202 return;
205 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
206 children.e0 = x;
207 children.e1 = y;
208 return;
209 case TensorExp::Kind::kBinary:
210 case TensorExp::Kind::kReduce:
211 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
212 children.e0 = x;
213 children.e1 = y;
214 return;
215 case TensorExp::Kind::kDenseOp:
216 assert(x != detail::kInvalidId && !v && o);
217 children.e0 = x;
218 children.e1 = y;
219 return;
220 }
221 llvm_unreachable("unexpected kind");
222}
223
224Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
225 unsigned maxLvlRank)
226 : outTensor(numInputOutputTensors - 1),
227 syntheticTensor(numInputOutputTensors),
228 numTensors(numInputOutputTensors + 1), numLoops(numLoops),
229 hasSparseOut(false),
230 lvlTypes(numTensors,
231 std::vector<LevelType>(numLoops, LevelFormat::Undef)),
232 loopToLvl(numTensors,
233 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
234 lvlToLoop(numTensors,
235 std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
236 loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
237 numTensors, std::nullopt)),
238 levelToDependentLoop(numTensors,
239 std::vector<std::vector<LoopCoeffPair>>(
240 maxLvlRank, std::vector<LoopCoeffPair>())),
241 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
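// Example (illustrative, using a hypothetical kernel): for x(i) = a(i) + b(i)
// with two inputs and one output, a merger would be constructed as
//   Merger merger(/*numInputOutputTensors=*/3, /*numLoops=*/1,
//                 /*maxLvlRank=*/1);
// which assigns tensor ids 0..1 to the inputs, id 2 to the output
// (outTensor), and reserves id 3 for the synthetic tensor, so numTensors
// becomes 4.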
242
243//===----------------------------------------------------------------------===//
244// Lattice methods.
245//===----------------------------------------------------------------------===//
246
247ExprId Merger::addTensorExp(TensorId t) {
248 assert(isValidTensorId(t));
249 const ExprId eNew(tensorExps.size());
250 tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
251 Value(), nullptr, nullptr);
252 return eNew;
253}
254
255ExprId Merger::addLoopVarExp(LoopId i) {
256 assert(isValidLoopId(i));
257 const ExprId eNew(tensorExps.size());
258 tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
259 Value(), nullptr, nullptr);
260 return eNew;
261}
262
263ExprId Merger::addInvariantExp(Value v) {
264 const ExprId eNew(tensorExps.size());
265 tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
266 detail::kInvalidId, v, nullptr, nullptr);
267 return eNew;
268}
269
270ExprId Merger::addSynZeroExp() {
271 const ExprId eNew(tensorExps.size());
272 tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
273 detail::kInvalidId, Value(), nullptr, nullptr);
274 return eNew;
275}
276
277ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
278 Attribute attr) {
279 assert(k > TensorExp::Kind::kLoopVar);
280 const ExprId eNew(tensorExps.size());
281 tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
282 return eNew;
283}
284
285ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
286 Attribute attr) {
287 assert(k > TensorExp::Kind::kLoopVar);
288 const ExprId eNew(tensorExps.size());
289 tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
290 return eNew;
291}
292
293LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
294 const LatPointId pNew(latPoints.size());
295 const unsigned size = numLoops * numTensors;
296 const TensorLoopId b = makeTensorLoopId(t, i);
297 latPoints.emplace_back(size, e);
298 latPoints[pNew].bits.set(b);
299 return pNew;
300}
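// Example (illustrative): with numTensors = 4 and numLoops = 2, each lattice
// point carries an 8-bit conjunction. A call such as
//   addLat(/*t=*/1, /*i=*/0, e);
// creates a point whose only condition is the single TensorLoopId bit
// "tensor 1 is accessed in loop 0", associated with expression e.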
301
302LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
303 assert(bits.size() == numLoops * numTensors);
304 const LatPointId pNew(latPoints.size());
305 latPoints.emplace_back(bits, e);
306 return pNew;
307}
308
309LatSetId Merger::addSet() {
310 const LatSetId sNew(latSets.size());
311 latSets.emplace_back();
312 return sNew;
313}
314
315LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
316 Operation *op) {
317 TensorExp::Kind kind = exp(e).kind;
318 Attribute attr = exp(e).attr;
319 const LatPointId pNew(latPoints.size());
320 const auto &point0 = lat(p0);
321 const auto &point1 = lat(p1);
322 BitVector bits(point0.bits);
323 bits |= point1.bits;
324 const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
325 latPoints.emplace_back(bits, ne);
326 return pNew;
327}
328
329LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
330 const LatSetId sNew = addSet();
331 auto &setNew = latSets[sNew];
332 for (const LatPointId p0 : set(s0))
333 for (const LatPointId p1 : set(s1))
334 setNew.push_back(conjLat(e, p0, p1, op));
335 return sNew;
336}
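// Example (illustrative): conjSet forms the cross product of the two sets,
// so for s0 = {p0, p1} and s1 = {q0, q1} the result is
// {p0&q0, p0&q1, p1&q0, p1&q1}, where each "&" unions the bits of the two
// points via conjLat and rebuilds the expression under `e`.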
337
338LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
339 const LatSetId sNew = conjSet(e, s0, s1, op);
340 TensorExp::Kind kind = exp(e).kind;
341 // Followed by all in s0.
342 latSets[sNew].append(latSets[s0]);
343 // Map binary 0-y to unary -y.
344 // TODO: move this if-else logic into buildLattices
345 if (kind == TensorExp::Kind::kSubF)
346 s1 = mapSet(TensorExp::Kind::kNegF, s1);
347 else if (kind == TensorExp::Kind::kSubC)
348 s1 = mapSet(TensorExp::Kind::kNegC, s1);
349 else if (kind == TensorExp::Kind::kSubI)
350 s1 = mapSet(TensorExp::Kind::kNegI, s1);
351 // Followed by all in s1.
352 latSets[sNew].append(latSets[s1]);
353 return sNew;
354}
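// Example (illustrative): for x = a(i) + b(i) with both operands sparse,
// disjSet yields the three-point set {a&b, a, b}, matching the x+y table in
// buildLattices below. For a subtraction a(i) - b(i), the points contributed
// by s1 alone are first rewritten to the unary negation -b by the mapSet
// calls above.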
355
356LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
357 assert(exp(e).kind == TensorExp::Kind::kCmpI ||
358 exp(e).kind == TensorExp::Kind::kCmpF);
359 const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360
361 ExprId e0 = exp(e).children.e0;
362 ExprId e1 = exp(e).children.e1;
363 if (exp(e0).kind == TensorExp::Kind::kSynZero ||
364 exp(e1).kind == TensorExp::Kind::kSynZero) {
365 // lhs and rhs can't be synthetic zero at the same time.
366 assert(exp(e0).kind != exp(e1).kind);
367 // If one of the operands has already been assigned to zero (the
368 // element is absent in the corresponding operand), then we do not
369 // need to build a disjunctive set for it.
370 return sNew;
371 }
372
373 auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
374 auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
375 latSets[sNew].append(latSets[lhsSet]);
376 latSets[sNew].append(latSets[rhsSet]);
377 return sNew;
378}
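// Example (illustrative): for x = a(i) < b(i), the result contains the
// overlap points (a < b), the points where only a is present (a < 0, built
// by substituting a synthetic zero for the rhs), and the points where only
// b is present (0 < b), matching the comparison table in buildLattices
// below.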
379
380LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
381 bool includeLeft, TensorExp::Kind ltrans,
382 Operation *opleft, bool includeRight,
383 TensorExp::Kind rtrans, Operation *opright) {
384 Attribute a = exp(e).attr;
385 const LatSetId sNew = conjSet(e, s0, s1, orig);
386 // Left Region.
387 if (includeLeft) {
388 if (opleft)
389 s0 = mapSet(ltrans, s0, Value(), opleft, a);
390 latSets[sNew].append(latSets[s0]);
391 }
392 // Right Region.
393 if (includeRight) {
394 if (opright)
395 s1 = mapSet(rtrans, s1, Value(), opright, a);
396 latSets[sNew].append(latSets[s1]);
397 }
398 return sNew;
399}
400
401LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
402 Operation *op, Attribute a) {
403 assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
404 kind == TensorExp::Kind::kDenseOp);
405 const LatSetId sNew = addSet();
406 auto &setNew = latSets[sNew];
407 for (const LatPointId p : set(s0)) {
408 const auto &point = latPoints[p];
409 setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
410 }
411 return sNew;
412}
413
414LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
415 TensorExp::Kind kind = exp(e).kind;
416 Attribute a = exp(e).attr;
417 assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
418 // Must be a binary operation.
419 const LatSetId sNew = addSet();
420 auto &setNew = latSets[sNew];
421 const ExprId zeroExp = addSynZeroExp();
422 for (const LatPointId p : set(s0)) {
423 const auto &point = latPoints[p];
424 ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
425 : addExp(kind, point.exp, zeroExp, nullptr, a);
426 setNew.push_back(addLat(point.bits, newExp));
427 }
428 return sNew;
429}
430
431LatSetId Merger::optimizeSet(LatSetId s0) {
432 const LatSetId sNew = addSet();
433 auto &setNew = latSets[sNew];
434 const auto &set0 = set(s0);
435 assert(!set0.empty());
436 const LatPointId p0 = set0[0];
437 for (const LatPointId p1 : set0) {
438 bool add = true;
439 if (p0 != p1) {
440 // Check whether this is a straightforward copy.
441 if (expIsTensor(latPoints[p1].exp, outTensor))
442 continue;
443 // Check whether this conjunction is already covered.
444 for (const LatPointId p2 : setNew) {
445 assert(!latGT(p1, p2)); // Lj => Li would be bad
446 if (onlyDenseDiff(p2, p1)) {
447 add = false;
448 break;
449 }
450 }
451 assert(!add || latGT(p0, p1));
452 }
453 if (add)
454 setNew.push_back(p1);
455 }
456 for (const LatPointId p : setNew)
457 latPoints[p].simple = simplifyCond(sNew, p);
458 return sNew;
459}
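// Example (illustrative): for x = d(i) + s(i) with d dense and s sparse,
// the disjunction produces the points {d&s, d, s}. The point {s} differs
// from {d&s} only in the dense condition d, which always holds, so
// optimizeSet drops it and keeps {d&s, d}.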
460
461BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
462 // First determine if this lattice point is a *singleton*, i.e.,
463 // the last point in a lattice: no other point is less than this one.
464 bool isSingleton = true;
465 for (const LatPointId p1 : set(s0)) {
466 if (p0 != p1 && latGT(p0, p1)) {
467 isSingleton = false;
468 break;
469 }
470 }
471
472 BitVector simple(latPoints[p0].bits);
473 bool reset = isSingleton && hasAnySparse(simple);
474 const TensorLoopId be = simple.size();
475 TensorLoopId offset = 0; // relative to the end
476 if (!reset)
477 // Start resetting from a dense level, so that the first bit (if kept)
478 // is not of an undefined level-type.
479 for (unsigned b = 0; b < be; b++) {
480 if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
481 offset = be - b - 1; // relative to the end
482 break;
483 }
484 }
485
486 // Now apply the two basic rules. We also iterate the bits in reverse
487 // to always keep the rightmost bit (which could be a synthetic tensor).
488 for (unsigned b = be - 1 - offset, i = 0; i < be;
489 b = b == 0 ? be - 1 : b - 1, i++) {
490 // Slice on dense level has `locate` property as well, and can be optimized.
491 if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
492 const auto lt = getLvlType(b);
493 if (!lt.hasSparseSemantic()) {
494 if (reset)
495 simple.reset(b);
496 reset = true;
497 }
498 }
499 }
500 return simple;
501}
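// Example (illustrative): for a singleton point with bits {d, s}, d dense
// and s sparse, the simplified conjunction keeps only {s}: the dense
// condition can be resolved by direct indexing ("locate"), so only the
// sparse level still needs an explicit presence test.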
502
503bool Merger::latGT(LatPointId i, LatPointId j) const {
504 const BitVector &bitsi = lat(i).bits;
505 const BitVector &bitsj = lat(j).bits;
506 assert(bitsi.size() == bitsj.size());
507 if (bitsi.count() > bitsj.count()) {
508 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509 if (bitsj[b] && !bitsi[b])
510 return false;
511 return true;
512 }
513 return false;
514}
515
516bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
517 BitVector tmp(latPoints[j].bits);
518 tmp ^= latPoints[i].bits;
519 return !hasAnySparse(tmp);
520}
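// Example (illustrative): latGT({d, s}, {s}) holds because the first point
// strictly covers the second, and onlyDenseDiff({d, s}, {s}) holds because
// the two points differ only in the dense bit d; together these drive the
// subsumption test in optimizeSet above.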
521
522bool Merger::expContainsTensor(ExprId e, TensorId t) const {
523 const auto &expr = exp(e);
524 // First we check `expIsTensor`.
525 if (expr.kind == TensorExp::Kind::kTensor)
526 return expr.tensor == t;
527
528 switch (getExpArity(expr.kind)) {
529 case ExpArity::kNullary:
530 return false;
531 case ExpArity::kUnary: {
532 const ExprId e0 = expr.children.e0;
533 return expContainsTensor(e0, t);
534 }
535 case ExpArity::kBinary: {
536 const ExprId e0 = expr.children.e0;
537 const ExprId e1 = expr.children.e1;
538 return expContainsTensor(e0, t) || expContainsTensor(e1, t);
539 }
540 }
541 llvm_unreachable("unexpected arity");
542}
543
544bool Merger::hasNegateOnOut(ExprId e) const {
545 const auto &expr = exp(e);
546 switch (expr.kind) {
547 case TensorExp::Kind::kNegF:
548 case TensorExp::Kind::kNegC:
549 case TensorExp::Kind::kNegI:
550 return expContainsTensor(expr.children.e0, outTensor);
551 case TensorExp::Kind::kSubF:
552 case TensorExp::Kind::kSubC:
553 case TensorExp::Kind::kSubI:
554 return expContainsTensor(expr.children.e1, outTensor) ||
555 hasNegateOnOut(expr.children.e0);
556 case TensorExp::Kind::kDenseOp: {
557 bool lhsNeg = hasNegateOnOut(expr.children.e0);
558 if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
559 return hasNegateOnOut(expr.children.e1);
560 return lhsNeg;
561 }
562 default: {
563 switch (getExpArity(expr.kind)) {
564 case ExpArity::kNullary:
565 return false;
566 case ExpArity::kUnary:
567 return hasNegateOnOut(expr.children.e0);
568 case ExpArity::kBinary:
569 return hasNegateOnOut(expr.children.e0) ||
570 hasNegateOnOut(expr.children.e1);
571 }
572 }
573 }
574 llvm_unreachable("unexpected kind");
575}
576
577bool Merger::isSingleCondition(TensorId t, ExprId e) const {
578 assert(isValidTensorId(t));
579 const auto &expr = exp(e);
580 switch (expr.kind) {
581 // Leaf.
583 return expr.tensor == t;
587 return false;
588 // Unary operations.
622 return isSingleCondition(t, expr.children.e0);
625 return false;
626 // Binary operations.
627 case TensorExp::Kind::kDivF: // note: x / c only
631 assert(!maybeZero(expr.children.e1));
632 return isSingleCondition(t, expr.children.e0);
633 case TensorExp::Kind::kShrS: // note: x >> inv only
636 assert(isInvariant(expr.children.e1));
637 return isSingleCondition(t, expr.children.e0);
643 if (isSingleCondition(t, expr.children.e0))
644 return isSingleCondition(t, expr.children.e1) ||
645 isInvariant(expr.children.e1);
646 if (isSingleCondition(t, expr.children.e1))
647 return isInvariant(expr.children.e0);
648 return false;
652 return isSingleCondition(t, expr.children.e0) &&
653 isSingleCondition(t, expr.children.e1);
662 return false;
663 case TensorExp::Kind::kDenseOp:
664 // Since the Merger guarantees all operands of a kDenseOp to be dense,
665 // the operation must be single-condition.
666 return true;
667 }
668 llvm_unreachable("unexpected kind");
669}
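// Example (illustrative): isSingleCondition(a, e) is true for e = a(i) * c
// with c invariant, since the product is nonzero only where a is present,
// but false for e = a(i) + b(i), since the sum can be nonzero at
// coordinates where a is absent.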
670
671bool Merger::hasAnySparse(const BitVector &bits) const {
672 for (TensorLoopId b : bits.set_bits()) {
673 const auto lt = getLvlType(b);
674 if (lt.hasSparseSemantic())
675 return true;
676 }
677 return hasSparseIdxReduction(bits);
678}
679
680bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
681 for (TensorLoopId b : bits.set_bits())
682 if (isSparseLvlWithNonTrivialIdxExp(b))
683 return true;
684 return false;
685}
686
687#ifndef NDEBUG
688
689//===----------------------------------------------------------------------===//
690// Print methods (for debugging).
691//===----------------------------------------------------------------------===//
692
693static const char *kindToOpSymbol(TensorExp::Kind kind) {
694 switch (kind) {
695 // Leaf.
697 return "tensor";
699 return "invariant";
701 return "index";
703 return "0";
704 // Unary operations.
708 return "abs";
710 return "ceil";
712 return "floor";
715 return "sqrt";
718 return "expm1";
721 return "log1p";
723 return "relu";
726 return "sin";
729 return "tanh";
733 return "-";
745 return "complex.im";
747 return "complex.re";
749 return "cast";
751 return "binary_branch";
753 return "unary";
755 return "select";
756 // Binary operations.
760 return "*";
765 return "/";
769 return "+";
773 return "-";
775 return "&";
777 return "|";
779 return "^";
781 return "a>>";
783 return ">>";
785 return "<<";
788 return "cmp";
790 return "binary";
792 return "reduce";
794 return "dense";
795 }
796 llvm_unreachable("unexpected kind for symbol");
797}
798
799void Merger::dumpExp(ExprId e) const {
800 const auto &expr = exp(e);
801 switch (expr.kind) {
802 // Leaf.
804 if (expr.tensor == syntheticTensor)
805 llvm::dbgs() << "synthetic_";
806 else if (expr.tensor == outTensor)
807 llvm::dbgs() << "output_";
808 llvm::dbgs() << "tensor_" << expr.tensor;
809 break;
811 llvm::dbgs() << "invariant";
812 break;
814 llvm::dbgs() << "0";
815 break;
817 llvm::dbgs() << "loopvar_" << expr.loop;
818 break;
819 // Unary operations.
855 llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
856 dumpExp(expr.children.e0);
857 break;
858 // Binary operations.
883 llvm::dbgs() << "(";
884 dumpExp(expr.children.e0);
885 llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
886 if (expr.attr)
887 llvm::dbgs() << "{" << expr.attr << "}";
888 if (expr.children.e1 != detail::kInvalidId) {
889 llvm::dbgs() << " ";
890 dumpExp(expr.children.e1);
891 llvm::dbgs() << ")";
892 } else {
893 assert(expr.kind == TensorExp::Kind::kDenseOp);
894 }
895 break;
896 }
897}
898
899void Merger::dumpLat(LatPointId p) const {
900 const auto &point = lat(p);
901 llvm::dbgs() << "lat(";
902 dumpBits(point.bits);
903 llvm::dbgs() << " :";
904 dumpBits(point.simple);
905 llvm::dbgs() << " : ";
906 dumpExp(point.exp);
907 llvm::dbgs() << " )\n";
908}
909
910void Merger::dumpSet(LatSetId s) const {
911 const auto &ss = set(s);
912 llvm::dbgs() << "{ #" << ss.size() << "\n";
913 for (const LatPointId p : ss) {
914 llvm::dbgs() << " ";
915 dumpLat(p);
916 }
917 llvm::dbgs() << "}\n";
918}
919
920void Merger::dumpBits(const BitVector &bits) const {
921 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
922 if (bits[b]) {
923 const TensorId t = tensor(b);
924 const LoopId i = loop(b);
925 const auto lt = lvlTypes[t][i];
926 if (isLvlWithNonTrivialIdxExp(b))
927 llvm::dbgs() << " DEP_" << t << "_" << i;
928 else
929 llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
930 }
931 }
932}
933
934#endif // NDEBUG
935
936//===----------------------------------------------------------------------===//
937// Builder methods.
938//===----------------------------------------------------------------------===//
939
940LatSetId Merger::buildLattices(ExprId e, LoopId i) {
941 // NOTE: The `expr` reference will be invalidated by recursive calls
942 // (and any other method that may add new expressions); therefore, the
943 // code below must make sure to copy fields of `expr` into local variables
944 // before making any recursive calls.
945 const auto &expr = exp(e);
946 const TensorExp::Kind kind = expr.kind;
947 switch (kind) {
948 // Leaf.
949 case TensorExp::Kind::kTensor:
950 case TensorExp::Kind::kInvariant:
951 case TensorExp::Kind::kSynZero:
952 case TensorExp::Kind::kLoopVar: {
953 // Either the loop-var is really used in the tensor expression, or it is
954 // set to the undefined loop-var in that level. An invariant expression,
955 // a proper index value, and a truly dynamic sparse output tensor are set
956 // to a synthetic tensor with undefined indices only to ensure the
957 // iteration space is not skipped as a result of their contents.
958 const LatSetId s = addSet();
959 TensorId t = syntheticTensor;
960 if (kind == TensorExp::Kind::kTensor) {
961 t = expr.tensor;
962 if (hasSparseOut && t == outTensor)
963 t = syntheticTensor;
964 }
965 latSets[s].push_back(addLat(t, i, e));
966 return s;
967 }
968 // Unary operations.
1001 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
1002 // lattice set of the operand through the operator into a new set.
1003 //
1004 // -y|!y | y |
1005 // --+---+---+
1006 // | 0 |-y |
1007 {
1008 const ExprId e0 = expr.children.e0;
1009 const Value v = expr.val;
1010 Attribute a = expr.attr;
1011 return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
1012 }
1015 // The left or right half of a binary operation which has already
1016 // been split into separate operations for each region.
1017 {
1018 const ExprId e0 = expr.children.e0;
1019 Operation *const op = expr.op;
1020 return mapSet(kind, buildLattices(e0, i), Value(), op);
1021 }
1022 case TensorExp::Kind::kUnary:
1023 // A custom unary operation.
1024 //
1025 // op y| !y | y |
1026 // ----+----------+------------+
1027 // | absent() | present(y) |
1028 {
1029 const ExprId e0 = expr.children.e0;
1030 UnaryOp unop = cast<UnaryOp>(expr.op);
1031 const LatSetId child0 = buildLattices(e0, i);
1032 Region &absentRegion = unop.getAbsentRegion();
1033 if (absentRegion.empty()) {
1034 // Simple mapping over existing values.
1035 return mapSet(kind, child0, Value(), unop);
1036 }
1037 // Use a disjunction with `unop` on the left and the absent value as an
1038 // invariant on the right.
1039 Block &absentBlock = absentRegion.front();
1040 YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1041 const Value absentVal = absentYield.getSingleResult();
1042 const ExprId rhs = addInvariantExp(absentVal);
1043 return disjSet(e, child0, buildLattices(rhs, i), unop);
1044 }
1045 // Binary operations.
1046 case TensorExp::Kind::kMulF:
1047 case TensorExp::Kind::kMulC:
1048 case TensorExp::Kind::kMulI:
1049 case TensorExp::Kind::kAndI:
1050 // A multiplicative operation only needs to be performed
1051 // for the conjunction of sparse iteration spaces.
1052 //
1053 // x*y|!y | y |
1054 // ---+---+---+
1055 // !x | 0 | 0 |
1056 // x | 0 |x*y|
1057 //
1058 // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1059 {
1060 const ExprId e0 = expr.children.e0;
1061 const ExprId e1 = expr.children.e1;
1062 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1063 }
1064 case TensorExp::Kind::kDivF:
1065 case TensorExp::Kind::kDivC:
1066 case TensorExp::Kind::kDivS:
1067 case TensorExp::Kind::kDivU:
1068 // A division is tricky, since 0/0, 0/c, c/0 all have
1069 // specific outcomes for floating-point and integers.
1070 // Thus, we need to traverse the full iteration space.
1071 //
1072 // x/y|!y | y |
1073 // ---+---+---+
1074 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1075 // x |x/0|x/y| INT: x/0=exception for any x
1076 //
1077 // TODO: for now we "fixed" this by only accepting x/c cases
1078 // during expression building, so that the conjunction
1079 // rules applies (viz. x/c = x*(1/c) as far as lattice
1080 // construction is concerned).
1081 {
1082 const ExprId e0 = expr.children.e0;
1083 const ExprId e1 = expr.children.e1;
1084 assert(!maybeZero(e1));
1085 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1086 }
1087 case TensorExp::Kind::kAddF:
1088 case TensorExp::Kind::kAddC:
1089 case TensorExp::Kind::kAddI:
1090 case TensorExp::Kind::kSubF:
1091 case TensorExp::Kind::kSubC:
1092 case TensorExp::Kind::kSubI:
1093 case TensorExp::Kind::kOrI:
1094 case TensorExp::Kind::kXorI:
1095 // An additive operation needs to be performed
1096 // for the disjunction of sparse iteration spaces.
1097 //
1098 // x+y|!y | y | x-y|!y | y |
1099 // ---+---+---+ ---+---+---+
1100 // !x | 0 | y | !x | 0 |-y |
1101 // x | x |x+y| x | x |x-y|
1102 {
1103 const ExprId e0 = expr.children.e0;
1104 const ExprId e1 = expr.children.e1;
1105 return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1106 }
1107 case TensorExp::Kind::kCmpF:
1108 case TensorExp::Kind::kCmpI:
1109 // A comparison operation needs to be performed
1110 // for the disjunction of sparse iteration spaces.
1111 //
1112 // x < y | !y | y |
1113 // -------+-------+-------+
1114 // !x | 0 | 0 < y |
1115 // x | x < 0 | x < y |
1116 {
1117 const ExprId e0 = expr.children.e0;
1118 const ExprId e1 = expr.children.e1;
1119 return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1120 }
1121 case TensorExp::Kind::kShrS:
1122 case TensorExp::Kind::kShrU:
1123 case TensorExp::Kind::kShlI:
1124 // A shift operation by an invariant amount (viz. tensor expressions
1125 // can only occur at the left-hand-side of the operator) can be handled
1126 // with the conjunction rule.
1127 {
1128 const ExprId e0 = expr.children.e0;
1129 const ExprId e1 = expr.children.e1;
1130 assert(isInvariant(e1));
1131 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1132 }
1133 case TensorExp::Kind::kBinary:
1134 // A custom binary operation.
1135 //
1136 // x op y| !y | y |
1137 // ------+---------+--------------+
1138 // !x | empty | right(y) |
1139 // x | left(x) | overlap(x,y) |
1140 {
1141 const ExprId e0 = expr.children.e0;
1142 const ExprId e1 = expr.children.e1;
1143 BinaryOp binop = cast<BinaryOp>(expr.op);
1144 const LatSetId child0 = buildLattices(e0, i);
1145 const LatSetId child1 = buildLattices(e1, i);
1146 Region &leftRegion = binop.getLeftRegion();
1147 Region &rightRegion = binop.getRightRegion();
1148 // Left Region.
1149 Operation *leftYield = nullptr;
1150 if (!leftRegion.empty()) {
1151 Block &leftBlock = leftRegion.front();
1152 leftYield = leftBlock.getTerminator();
1153 }
1154 // Right Region.
1155 Operation *rightYield = nullptr;
1156 if (!rightRegion.empty()) {
1157 Block &rightBlock = rightRegion.front();
1158 rightYield = rightBlock.getTerminator();
1159 }
1160 bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1161 bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
1162 return combiSet(e, child0, child1, binop, includeLeft,
1163 TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
1164 TensorExp::Kind::kBinaryBranch, rightYield);
1165 }
1166 case TensorExp::Kind::kReduce:
1167 // A custom reduce operation.
1168 {
1169 const ExprId e0 = expr.children.e0;
1170 const ExprId e1 = expr.children.e1;
1171 Operation *const op = expr.op;
1172 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1173 }
1174 case TensorExp::Kind::kDenseOp: {
1175 // It does not matter whether we use a conjunctive or disjunctive set
1176 // here: all operands of a kDenseOp must be dense, so a disjunctive set
1177 // would be optimized into a conjunctive set eventually.
1178 if (expr.children.e1 == detail::kInvalidId) {
1179 const ExprId e0 = expr.children.e0;
1180 Operation *const op = expr.op;
1181 return mapSet(kind, buildLattices(e0, i), Value(), op);
1182 }
1183
1184 const ExprId e0 = expr.children.e0;
1185 const ExprId e1 = expr.children.e1;
1186 Operation *const op = expr.op;
1187 return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1188 }
1189 }
1190 llvm_unreachable("unexpected expression kind");
1191}
1192
1193std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1194 // Build the linalg semantics backward from yield.
1195 Operation *yield = op.getRegion().front().getTerminator();
1196 assert(isa<linalg::YieldOp>(yield));
1197 return buildTensorExp(op, yield->getOperand(0)).first;
1198}
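// Example (illustrative): for a linalg.generic whose body is
//   %0 = arith.mulf %a, %b : f64
//   linalg.yield %0 : f64
// the traversal starts at the yielded value %0, builds kTensor leaves for
// the block arguments %a and %b, and joins them in a kMulF expression.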
1199
1200/// Only returns true if we are certain this is a zero.
1201static bool isCertainZero(Value val) {
1202 if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
1203 ArrayAttr arrayAttr = c.getValue();
1204 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1205 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1206 }
1207 if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
1208 return c.value() == 0;
1209 if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
1210 return c.value().isZero();
1211 return false;
1212}
1213
1214/// Only returns false if we are certain this is a nonzero.
1215bool Merger::maybeZero(ExprId e) const {
1216 const auto &expr = exp(e);
1217 if (expr.kind == TensorExp::Kind::kInvariant) {
1218 // Note that this is different from isCertainZero() in a subtle
1219 // way by always returning true for non-constants.
1220 if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1221 ArrayAttr arrayAttr = c.getValue();
1222 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1223 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1224 }
1225 if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1226 return c.value() == 0;
1227 if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1228 return c.value().isZero();
1229 }
1230 return true;
1231}
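// Example (illustrative): maybeZero returns false for an invariant
// arith.constant 2.0 (certainly nonzero), but true for any loop-variant or
// otherwise unknown value; this is what restricts x/y to the x/c form
// during expression building below.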
1232
1233Type Merger::inferType(ExprId e, Value src) const {
1234 // Obtain the destination type from the cast node.
1235 Type dtp = exp(e).val.getType();
1236 // Inspect source type. For vector types, apply the same
1237 // vectorization to the destination type.
1238 if (auto vtp = dyn_cast<VectorType>(src.getType()))
1239 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1240 return dtp;
1241}
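// Example (illustrative): if a cast node has destination type f64 and the
// source value has type vector<4xf32>, inferType returns vector<4xf64>,
// preserving the number of elements and the scalable-dims information of
// the source vector type.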
1242
1243/// Ensures that the sparsifier can generate code for the expression.
1244static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1245 // Arguments are always admissible.
1246 if (isa<BlockArgument>(v))
1247 return true;
1248 // Accept index anywhere.
1249 Operation *def = v.getDefiningOp();
1250 if (isa<linalg::IndexOp>(def))
1251 return true;
1252 // Operation defined outside branch.
1253 if (def->getBlock() != block)
1254 return def->getBlock() != op->getBlock(); // invariant?
1255 // Operation defined within branch. Anything is accepted,
1256 // as long as all subexpressions are admissible.
1257 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
1258 if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1259 return false;
1260 return true;
1261}
1262
1263/// Ensures that the sparsifier can generate code for the branch.
1264static bool isAdmissibleBranch(Operation *op, Region &region) {
1265 if (region.empty())
1266 return true;
1267 // Build the semi-ring branch semantics backward from yield.
1268 Operation *yield = region.front().getTerminator();
1269 assert(isa<YieldOp>(yield));
1270 return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1271}
1272
1273// Recognizes a direct GT comparison.
1274static bool isGreater(TensorExp::Kind kind, Attribute attr) {
1275 if (kind == TensorExp::Kind::kCmpI) {
1276 auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
1277 return pred == arith::CmpIPredicate::ugt ||
1278 pred == arith::CmpIPredicate::sgt;
1279 }
1280 if (kind == TensorExp::Kind::kCmpF) {
1281 auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
1282 return pred == arith::CmpFPredicate::UGT ||
1283 pred == arith::CmpFPredicate::OGT;
1284 }
1285 return false;
1286}
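// Example (illustrative): isGreater accepts arith.cmpi with predicate
// "sgt"/"ugt" and arith.cmpf with predicate "OGT"/"UGT"; this lets the
// ternary pattern select(x > 0, x, 0) in buildTensorExp below be
// recognized as a kRelu expression.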
1287
1288std::pair<std::optional<ExprId>, bool>
1289Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1290 // Recursion leaves.
1291 if (auto arg = dyn_cast<BlockArgument>(v)) {
1292 const TensorId tid = makeTensorId(arg.getArgNumber());
1293 // Any argument of the generic op that is not marked as a scalar
1294 // argument is considered a tensor, indexed by the implicit loop
1295 // bounds. This includes rank-0 tensor arguments.
1296 if (arg.getOwner()->getParentOp() == op) {
1297 OpOperand &t = op->getOpOperand(tid);
1298 bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1299 if (!op.isScalar(&t))
1300 return {addTensorExp(tid), hasSpDep};
1301 v = t.get(); // get scalar value
1302 }
1303 // Any other argument (marked as scalar argument for the generic op
1304 // or belonging to an enveloping op) is considered invariant.
1305 return {addInvariantExp(v), /*hasSpDep=*/false};
1306 }
1307
1308 // Something defined outside is invariant.
1309 Operation *def = v.getDefiningOp();
1310 if (def->getBlock() != &op.getRegion().front())
1311 return {addInvariantExp(v), /*hasSpDep=*/false};
1312 // Construct index operations.
1313 if (def->getNumOperands() == 0) {
1314 if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
1315 return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1316 }
1317
1318 // Construct unary operations if subexpression can be built.
1319 if (def->getNumOperands() == 1) {
1320 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1321 if (x.has_value()) {
1322 const ExprId e = *x;
1323 if (isa<math::AbsFOp>(def))
1324 return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
1325 if (isa<complex::AbsOp>(def))
1326 return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
1327 if (isa<math::AbsIOp>(def))
1328 return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
1329 if (isa<math::CeilOp>(def))
1330 return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
1331 if (isa<math::FloorOp>(def))
1332 return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
1333 if (isa<math::SqrtOp>(def))
1334 return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
1335 if (isa<complex::SqrtOp>(def))
1336 return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
1337 if (isa<math::ExpM1Op>(def))
1338 return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
1339 if (isa<complex::Expm1Op>(def))
1340 return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
1341 if (isa<math::Log1pOp>(def))
1342 return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
1343 if (isa<complex::Log1pOp>(def))
1344 return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
1345 if (isa<math::SinOp>(def))
1346 return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
1347 if (isa<complex::SinOp>(def))
1348 return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
1349 if (isa<math::TanhOp>(def))
1350 return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
1351 if (isa<complex::TanhOp>(def))
1352 return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
1353 if (isa<arith::NegFOp>(def))
1354 return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
1355 if (isa<complex::NegOp>(def))
1356 return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
1357 if (isa<arith::TruncFOp>(def))
1358 return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
1359 if (isa<arith::ExtFOp>(def))
1360 return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
1361 if (isa<arith::FPToSIOp>(def))
1362 return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
1363 if (isa<arith::FPToUIOp>(def))
1364 return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
1365 if (isa<arith::SIToFPOp>(def))
1366 return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
1367 if (isa<arith::UIToFPOp>(def))
1368 return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
1369 if (isa<arith::ExtSIOp>(def))
1370 return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
1371 if (isa<arith::ExtUIOp>(def))
1372 return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
1373 if (isa<arith::IndexCastOp>(def))
1374 return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
1375 if (isa<arith::TruncIOp>(def))
1376 return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
1377 if (isa<complex::ImOp>(def))
1378 return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
1379 if (isa<complex::ReOp>(def))
1380 return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
1381 if (isa<arith::BitcastOp>(def))
1382 return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
1383 if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1384 if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
1385 isAdmissibleBranch(unop, unop.getAbsentRegion()))
1386 return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
1387 }
1388 if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1389 if (isAdmissibleBranch(selop, selop.getRegion()))
1390 return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
1391 }
1392 }
1393 }
1394
1395 // Construct binary operations if subexpressions can be built.
1396 // See buildLattices() for an explanation of rejecting certain
1397 // division and shift operations.
1398 if (def->getNumOperands() == 2) {
1399 const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
1400 const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
1401 // For a conjunctive operation, it yields a "sparse" result if any operand
1402 // is sparse. For a disjunctive operation, it yields a "sparse" result if
1403 // all operands are sparse.
1404 bool conjSpVals = xSpVals || ySpVals;
1405 bool disjSpVals = xSpVals && ySpVals;
1406 if (x.has_value() && y.has_value()) {
1407 const ExprId e0 = *x;
1408 const ExprId e1 = *y;
1409 if (isa<arith::MulFOp>(def))
1410 return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
1411 if (isa<complex::MulOp>(def))
1412 return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
1413 if (isa<arith::MulIOp>(def))
1414 return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
1415 if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1416 return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
1417 if (isa<complex::DivOp>(def) && !maybeZero(e1))
1418 return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
1419 if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1420 return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
1421 if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1422 return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
1423 if (isa<arith::AddFOp>(def))
1424 return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
1425 if (isa<complex::AddOp>(def))
1426 return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
1427 if (isa<arith::AddIOp>(def))
1428 return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
1429 if (isa<arith::SubFOp>(def))
1430 return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
1431 if (isa<complex::SubOp>(def))
1432 return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
1433 if (isa<arith::SubIOp>(def))
1434 return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
1435 if (isa<arith::AndIOp>(def))
1436 return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
1437 if (isa<arith::OrIOp>(def))
1438 return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
1439 if (isa<arith::XOrIOp>(def))
1440 return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
1441 if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1442 return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
1443 if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1444 return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
1445 if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1446 return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
1447 if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1448 if (ci.getPredicate() == arith::CmpIPredicate::eq ||
1449 ci.getPredicate() == arith::CmpIPredicate::sle ||
1450 ci.getPredicate() == arith::CmpIPredicate::sge ||
1451 ci.getPredicate() == arith::CmpIPredicate::ule ||
1452 ci.getPredicate() == arith::CmpIPredicate::uge) {
1453 // We cannot sparsify a comparison that includes equality, because
1454 // 0 <= 0 yields true and would thus densify the result.
1455 return {std::nullopt, false};
1456 }
1457
1458 auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1459 ci.getPredicateAttr());
1460 return {e, conjSpVals};
1461 }
1462 if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1463 if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
1464 cf.getPredicate() == arith::CmpFPredicate::OGE ||
1465 cf.getPredicate() == arith::CmpFPredicate::OLE ||
1466 cf.getPredicate() == arith::CmpFPredicate::ONE ||
1467 cf.getPredicate() == arith::CmpFPredicate::UEQ ||
1468 cf.getPredicate() == arith::CmpFPredicate::UGE ||
1469 cf.getPredicate() == arith::CmpFPredicate::ULE ||
1470 cf.getPredicate() == arith::CmpFPredicate::ORD ||
1471 cf.getPredicate() == arith::CmpFPredicate::UNO) {
1472 // We cannot sparsify a comparison that includes equality, because
1473 // 0 <= 0 yields true and would thus densify the result.
1474 return {std::nullopt, false};
1475 }
1476 auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1477 cf.getPredicateAttr());
1478 return {e, conjSpVals};
1479 }
1480 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1481 if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1482 (binop.getLeftIdentity() ||
1483 isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1484 (binop.getRightIdentity() ||
1485 isAdmissibleBranch(binop, binop.getRightRegion())))
1486 return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
1487 }
1488 }
1489 }
1490
1491 // Construct ternary operations if subexpressions can be built.
1492 if (def->getNumOperands() == 3) {
1493 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1494 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1495 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1496 bool hasSpDep = xDepSp || yDepSp || zDepSp;
1497 if (x.has_value() && y.has_value() && z.has_value()) {
1498 const ExprId e0 = *x;
1499 const ExprId e1 = *y;
1500 if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1501 if (isAdmissibleBranch(redop, redop.getRegion()))
1502 return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1503 }
1504 if (auto selop = dyn_cast<arith::SelectOp>(def)) {
1505 // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
1506 // operation inside a very specific ternary select operation.
1507 // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
1508 const auto &cnd = exp(*x);
1509 if (isGreater(cnd.kind, cnd.attr) &&
1510 exp(*y).kind == TensorExp::Kind::kTensor &&
1511 exp(*z).kind == TensorExp::Kind::kInvariant &&
1512 isCertainZero(exp(*z).val)) {
1513 const auto &a = exp(cnd.children.e0);
1514 const auto &b = exp(cnd.children.e1);
1515 if (a.kind == TensorExp::Kind::kTensor &&
1516 a.tensor == exp(*y).tensor &&
1517 b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
1518 return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
1519 nullptr, cnd.attr),
1520 yDepSp};
1521 }
1522 }
1523 }
1524 }
1525 }
1526
1527 // If we reach here, we are dealing with an operation that is not currently
1528 // sparsifiable. We can still generate code for it if all its operands only
1529 // have dense dependencies (i.e., all the values are loaded from dense
1530 // tensors).
1531 if (def->getNumResults() != 1) // only handle single result operation.
1532 return {std::nullopt, false};
1533 SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
1534 // Build all the subexpressions.
1535 for (Value operand : def->getOperands())
1536 subExp.push_back(buildTensorExp(op, operand));
1537
1538 if (llvm::all_of(subExp,
1539 [](auto e) { return e.first.has_value() && !e.second; })) {
1540 // All the subexpressions can be built and have *no* sparse dependencies.
1541 if (subExp.size() == 2) {
1542 auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1543 *subExp[1].first, def);
1544 return {e, false};
1545 }
1546 if (subExp.size() == 1) {
1547 auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
1548 detail::kInvalidId, def);
1549 return {e, false};
1550 }
1551 }
1552
1553 // Cannot build.
1554 return {std::nullopt, false};
1555}
1556
1557static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1558 ValueRange vals) {
1559 // Make a clone of overlap region.
1560 Region tmpRegion;
1561 IRMapping mapper;
1562 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1563 Block &clonedBlock = tmpRegion.front();
1564 YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1565 // Merge cloned block and return yield value.
1566 Operation *placeholder = arith::ConstantIndexOp::create(rewriter, loc, 0);
1567 rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
1568 Value val = clonedYield.getSingleResult();
1569 rewriter.eraseOp(clonedYield);
1570 rewriter.eraseOp(placeholder);
1571 return val;
1572}
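// Example (illustrative): given a present region such as
//   ^bb0(%x: f64):
//     %r = arith.mulf %x, %x : f64
//     sparse_tensor.yield %r : f64
// insertYieldOp clones the region, inlines the cloned block at the current
// insertion point with %x bound to the supplied value, erases the cloned
// yield and the temporary placeholder, and returns the cloned %r.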
1573
1574static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1575 Operation *op, Value v0) {
1576 if (!v0)
1577 // Empty input value must be propagated.
1578 return Value();
1579 UnaryOp unop = cast<UnaryOp>(op);
1580 Region &presentRegion = unop.getPresentRegion();
1581 if (presentRegion.empty())
1582 // Uninitialized Value() will be interpreted as missing data in the
1583 // output.
1584 return Value();
1585 return insertYieldOp(rewriter, loc, presentRegion, {v0});
1586}
1587
1588static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1589 Operation *op, Value v0, Value v1) {
1590 if (!v0 || !v1)
1591 // Empty input values must be propagated.
1592 return Value();
1593 BinaryOp binop = cast<BinaryOp>(op);
1594 Region &overlapRegion = binop.getOverlapRegion();
1595 if (overlapRegion.empty())
1596 // Uninitialized Value() will be interpreted as missing data in the
1597 // output.
1598 return Value();
1599 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1600}
1601
1602static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
1603 Attribute attr) {
1604 Type tp = v0.getType();
1605 auto zero =
1606 arith::ConstantOp::create(rewriter, loc, tp, rewriter.getZeroAttr(tp));
1607 Value cmp;
1608 if (isa<FloatType>(tp)) {
1609 auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
1610 cmp = arith::CmpFOp::create(rewriter, loc, pred, v0, zero);
1611 } else {
1612 auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
1613 cmp = arith::CmpIOp::create(rewriter, loc, pred, v0, zero);
1614 }
1615 return arith::SelectOp::create(rewriter, loc, cmp, v0, zero);
1616}
1617
1618Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1619 Value v1) const {
1620 const auto &expr = exp(e);
1621 switch (expr.kind) {
1622 // Leaf.
1627 llvm_unreachable("unexpected non-op");
1628 // Unary operations.
1630 return math::AbsFOp::create(rewriter, loc, v0);
1632 auto type = cast<ComplexType>(v0.getType());
1633 auto eltType = cast<FloatType>(type.getElementType());
1634 return complex::AbsOp::create(rewriter, loc, eltType, v0);
1635 }
1637 return math::AbsIOp::create(rewriter, loc, v0);
1639 return math::CeilOp::create(rewriter, loc, v0);
1641 return math::FloorOp::create(rewriter, loc, v0);
1643 return math::SqrtOp::create(rewriter, loc, v0);
1645 return complex::SqrtOp::create(rewriter, loc, v0);
1647 return math::ExpM1Op::create(rewriter, loc, v0);
1649 return complex::Expm1Op::create(rewriter, loc, v0);
1651 return math::Log1pOp::create(rewriter, loc, v0);
1653 return complex::Log1pOp::create(rewriter, loc, v0);
1655 return buildRelu(rewriter, loc, v0, expr.attr);
1657 return math::SinOp::create(rewriter, loc, v0);
1659 return complex::SinOp::create(rewriter, loc, v0);
1661 return math::TanhOp::create(rewriter, loc, v0);
1663 return complex::TanhOp::create(rewriter, loc, v0);
1665 return arith::NegFOp::create(rewriter, loc, v0);
1667 return complex::NegOp::create(rewriter, loc, v0);
1668 case TensorExp::Kind::kNegI: // no negi in std
1669 return arith::SubIOp::create(
1670 rewriter, loc,
1671 arith::ConstantOp::create(rewriter, loc, v0.getType(),
1672 rewriter.getZeroAttr(v0.getType())),
1673 v0);
1675 return arith::TruncFOp::create(rewriter, loc, inferType(e, v0), v0);
1677 return arith::ExtFOp::create(rewriter, loc, inferType(e, v0), v0);
1679 return arith::FPToSIOp::create(rewriter, loc, inferType(e, v0), v0);
1681 return arith::FPToUIOp::create(rewriter, loc, inferType(e, v0), v0);
1683 return arith::SIToFPOp::create(rewriter, loc, inferType(e, v0), v0);
1685 return arith::UIToFPOp::create(rewriter, loc, inferType(e, v0), v0);
1687 return arith::ExtSIOp::create(rewriter, loc, inferType(e, v0), v0);
1689 return arith::ExtUIOp::create(rewriter, loc, inferType(e, v0), v0);
1691 return arith::IndexCastOp::create(rewriter, loc, inferType(e, v0), v0);
1693 return arith::TruncIOp::create(rewriter, loc, inferType(e, v0), v0);
1694 case TensorExp::Kind::kCIm: {
1695 auto type = cast<ComplexType>(v0.getType());
1696 auto eltType = cast<FloatType>(type.getElementType());
1697 return complex::ImOp::create(rewriter, loc, eltType, v0);
1698 }
1699 case TensorExp::Kind::kCRe: {
1700 auto type = cast<ComplexType>(v0.getType());
1701 auto eltType = cast<FloatType>(type.getElementType());
1702 return complex::ReOp::create(rewriter, loc, eltType, v0);
1703 }
1705 return arith::BitcastOp::create(rewriter, loc, inferType(e, v0), v0);
1706 // Binary operations.
1708 return arith::MulFOp::create(rewriter, loc, v0, v1);
1710 return complex::MulOp::create(rewriter, loc, v0, v1);
1712 return arith::MulIOp::create(rewriter, loc, v0, v1);
1714 return arith::DivFOp::create(rewriter, loc, v0, v1);
1716 return complex::DivOp::create(rewriter, loc, v0, v1);
1718 return arith::DivSIOp::create(rewriter, loc, v0, v1);
1720 return arith::DivUIOp::create(rewriter, loc, v0, v1);
1722 return arith::AddFOp::create(rewriter, loc, v0, v1);
1724 return complex::AddOp::create(rewriter, loc, v0, v1);
1726 return arith::AddIOp::create(rewriter, loc, v0, v1);
1728 return arith::SubFOp::create(rewriter, loc, v0, v1);
1730 return complex::SubOp::create(rewriter, loc, v0, v1);
1732 return arith::SubIOp::create(rewriter, loc, v0, v1);
1734 return arith::AndIOp::create(rewriter, loc, v0, v1);
1736 return arith::OrIOp::create(rewriter, loc, v0, v1);
1738 return arith::XOrIOp::create(rewriter, loc, v0, v1);
1740 return arith::ShRSIOp::create(rewriter, loc, v0, v1);
1742 return arith::ShRUIOp::create(rewriter, loc, v0, v1);
1744 return arith::ShLIOp::create(rewriter, loc, v0, v1);
1746 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1747 return arith::CmpIOp::create(rewriter, loc, predicate, v0, v1);
1748 }
1750 auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1751 return arith::CmpFOp::create(rewriter, loc, predicate, v0, v1);
1752 }
1753 case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
1754 return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
1755 {v0});
1757 return buildUnaryPresent(rewriter, loc, expr.op, v0);
1759 return insertYieldOp(rewriter, loc,
1760 cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
1761 {v0});
1763 return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1765 ReduceOp redOp = cast<ReduceOp>(expr.op);
1766 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1767 }
1769 Operation *actualOp = expr.op;
1770 IRMapping mapping;
1771 mapping.map(actualOp->getOperand(0), v0);
1772 if (actualOp->getNumOperands() == 2)
1773 mapping.map(actualOp->getOperand(1), v1);
1774 return rewriter.clone(*actualOp, mapping)->getResult(0);
1775 }
1776 }
1777 llvm_unreachable("unexpected expression kind in build");
1778}
1779
1780} // namespace sparse_tensor
1781} // namespace mlir