16 #include "llvm/Support/Debug.h"
20 namespace sparse_tensor {
99 llvm_unreachable(
"unexpected kind");
108 : kind(k), val(v), op(o), attr(a) {
221 llvm_unreachable(
"unexpected kind");
226 : outTensor(numInputOutputTensors - 1),
227 syntheticTensor(numInputOutputTensors),
228 numTensors(numInputOutputTensors + 1), numLoops(numLoops),
232 loopToLvl(numTensors,
233 std::vector<std::optional<
Level>>(numLoops, std::nullopt)),
234 lvlToLoop(numTensors,
235 std::vector<std::optional<
LoopId>>(maxLvlRank, std::nullopt)),
236 loopToUnresolvedLvls(numLoops, std::vector<std::optional<
LvlLTPair>>(
237 numTensors, std::nullopt)),
238 levelToDependentLoop(numTensors,
241 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
248 assert(isValidTensorId(t));
249 const ExprId eNew(tensorExps.size());
251 Value(),
nullptr,
nullptr);
256 assert(isValidLoopId(i));
257 const ExprId eNew(tensorExps.size());
259 Value(),
nullptr,
nullptr);
264 const ExprId eNew(tensorExps.size());
271 const ExprId eNew(tensorExps.size());
280 const ExprId eNew(tensorExps.size());
281 tensorExps.emplace_back(k, e0, e1,
Value(), op, attr);
288 const ExprId eNew(tensorExps.size());
295 const unsigned size = numLoops * numTensors;
297 latPoints.emplace_back(size, e);
298 latPoints[pNew].bits.set(b);
303 assert(bits.size() == numLoops * numTensors);
305 latPoints.emplace_back(bits, e);
310 const LatSetId sNew(latSets.size());
311 latSets.emplace_back();
320 const auto &point0 =
lat(p0);
321 const auto &point1 =
lat(p1);
322 BitVector bits(point0.bits);
324 const ExprId ne =
addExp(kind, point0.exp, point1.exp, op, attr);
325 latPoints.emplace_back(bits, ne);
331 auto &setNew = latSets[sNew];
334 setNew.push_back(
conjLat(e, p0, p1, op));
342 latSets[sNew].append(latSets[s0]);
352 latSets[sNew].append(latSets[s1]);
366 assert(
exp(e0).kind !=
exp(e1).kind);
375 latSets[sNew].append(latSets[lhsSet]);
376 latSets[sNew].append(latSets[rhsSet]);
390 latSets[sNew].append(latSets[s0]);
396 latSets[sNew].append(latSets[s1]);
406 auto &setNew = latSets[sNew];
408 const auto &point = latPoints[p];
409 setNew.push_back(
addLat(point.bits,
addExp(kind, point.exp, v, op, a)));
420 auto &setNew = latSets[sNew];
423 const auto &point = latPoints[p];
424 ExprId newExp = lhsZero ?
addExp(kind, zeroExp, point.exp,
nullptr, a)
425 :
addExp(kind, point.exp, zeroExp,
nullptr, a);
426 setNew.push_back(
addLat(point.bits, newExp));
433 auto &setNew = latSets[sNew];
434 const auto &set0 =
set(s0);
435 assert(!set0.empty());
445 assert(!
latGT(p1, p2));
451 assert(!add ||
latGT(p0, p1));
454 setNew.push_back(p1);
464 bool isSingleton =
true;
466 if (p0 != p1 &&
latGT(p0, p1)) {
472 BitVector simple(latPoints[p0].bits);
479 for (
unsigned b = 0; b < be; b++) {
488 for (
unsigned b = be - 1 - offset, i = 0; i < be;
489 b = b == 0 ? be - 1 : b - 1, i++) {
493 if (!lt.hasSparseSemantic()) {
504 const BitVector &bitsi =
lat(i).
bits;
505 const BitVector &bitsj =
lat(
j).
bits;
506 assert(bitsi.size() == bitsj.size());
507 if (bitsi.count() > bitsj.count()) {
508 for (
TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
509 if (bitsj[b] && !bitsi[b])
517 BitVector tmp(latPoints[
j].bits);
518 tmp ^= latPoints[i].bits;
523 const auto &expr =
exp(e);
526 return expr.tensor == t;
532 const ExprId e0 = expr.children.e0;
536 const ExprId e0 = expr.children.e0;
537 const ExprId e1 = expr.children.e1;
541 llvm_unreachable(
"unexpected arity");
545 const auto &expr =
exp(e);
574 llvm_unreachable(
"unexpected kind");
578 assert(isValidTensorId(t));
579 const auto &expr =
exp(e);
583 return expr.tensor == t;
631 assert(!maybeZero(expr.children.e1));
636 assert(isInvariant(expr.children.e1));
645 isInvariant(expr.children.e1);
647 return isInvariant(expr.children.e0);
668 llvm_unreachable(
"unexpected kind");
674 if (lt.hasSparseSemantic())
751 return "binary_branch";
796 llvm_unreachable(
"unexpected kind for symbol");
800 const auto &expr =
exp(e);
804 if (expr.tensor == syntheticTensor)
805 llvm::dbgs() <<
"synthetic_";
806 else if (expr.tensor == outTensor)
807 llvm::dbgs() <<
"output_";
808 llvm::dbgs() <<
"tensor_" << expr.tensor;
811 llvm::dbgs() <<
"invariant";
817 llvm::dbgs() <<
"loopvar_" << expr.loop;
887 llvm::dbgs() <<
"{" << expr.attr <<
"}";
900 const auto &point =
lat(p);
901 llvm::dbgs() <<
"lat(";
903 llvm::dbgs() <<
" :";
905 llvm::dbgs() <<
" : ";
907 llvm::dbgs() <<
" )\n";
911 const auto &ss =
set(s);
912 llvm::dbgs() <<
"{ #" << ss.size() <<
"\n";
917 llvm::dbgs() <<
"}\n";
921 for (
TensorLoopId b = 0, be = bits.size(); b < be; b++) {
925 const auto lt = lvlTypes[t][i];
927 llvm::dbgs() <<
" DEP_" << t <<
"_" << i;
929 llvm::dbgs() <<
" i_" << t <<
"_" << i <<
"_" <<
toMLIRString(lt);
945 const auto &expr =
exp(e);
962 if (hasSparseOut && t == outTensor)
965 latSets[s].push_back(
addLat(t, i, e));
1008 const ExprId e0 = expr.children.e0;
1009 const Value v = expr.val;
1018 const ExprId e0 = expr.children.e0;
1029 const ExprId e0 = expr.children.e0;
1030 UnaryOp unop = cast<UnaryOp>(expr.op);
1032 Region &absentRegion = unop.getAbsentRegion();
1033 if (absentRegion.
empty()) {
1040 YieldOp absentYield = cast<YieldOp>(absentBlock.
getTerminator());
1041 const Value absentVal = absentYield.getSingleResult();
1060 const ExprId e0 = expr.children.e0;
1061 const ExprId e1 = expr.children.e1;
1082 const ExprId e0 = expr.children.e0;
1083 const ExprId e1 = expr.children.e1;
1084 assert(!maybeZero(e1));
1103 const ExprId e0 = expr.children.e0;
1104 const ExprId e1 = expr.children.e1;
1117 const ExprId e0 = expr.children.e0;
1118 const ExprId e1 = expr.children.e1;
1128 const ExprId e0 = expr.children.e0;
1129 const ExprId e1 = expr.children.e1;
1130 assert(isInvariant(e1));
1141 const ExprId e0 = expr.children.e0;
1142 const ExprId e1 = expr.children.e1;
1143 BinaryOp binop = cast<BinaryOp>(expr.op);
1146 Region &leftRegion = binop.getLeftRegion();
1147 Region &rightRegion = binop.getRightRegion();
1150 if (!leftRegion.
empty()) {
1156 if (!rightRegion.
empty()) {
1160 bool includeLeft = binop.getLeftIdentity() || !leftRegion.
empty();
1161 bool includeRight = binop.getRightIdentity() || !rightRegion.
empty();
1162 return combiSet(e, child0, child1, binop, includeLeft,
1169 const ExprId e0 = expr.children.e0;
1170 const ExprId e1 = expr.children.e1;
1179 const ExprId e0 = expr.children.e0;
1184 const ExprId e0 = expr.children.e0;
1185 const ExprId e1 = expr.children.e1;
1190 llvm_unreachable(
"unexpected expression kind");
1195 Operation *yield = op.getRegion().front().getTerminator();
1196 assert(isa<linalg::YieldOp>(yield));
1197 return buildTensorExp(op, yield->
getOperand(0)).first;
1203 ArrayAttr arrayAttr = c.getValue();
1204 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1205 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1208 return c.value() == 0;
1210 return c.value().isZero();
1215 bool Merger::maybeZero(
ExprId e)
const {
1216 const auto &expr =
exp(e);
1220 if (
auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1221 ArrayAttr arrayAttr = c.getValue();
1222 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1223 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1225 if (
auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1226 return c.value() == 0;
1227 if (
auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1228 return c.value().isZero();
1233 Type Merger::inferType(
ExprId e, Value src)
const {
1238 if (
auto vtp = dyn_cast<VectorType>(src.getType()))
1239 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1246 if (isa<BlockArgument>(v))
1250 if (isa<linalg::IndexOp>(def))
1254 return def->
getBlock() != op->getBlock();
1269 assert(isa<YieldOp>(yield));
1276 auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
1277 return pred == arith::CmpIPredicate::ugt ||
1278 pred == arith::CmpIPredicate::sgt;
1281 auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
1282 return pred == arith::CmpFPredicate::UGT ||
1283 pred == arith::CmpFPredicate::OGT;
1288 std::pair<std::optional<ExprId>,
bool>
1289 Merger::buildTensorExp(linalg::GenericOp op,
Value v) {
1291 if (
auto arg = dyn_cast<BlockArgument>(v)) {
1296 if (arg.getOwner()->getParentOp() == op) {
1299 if (!op.isScalar(&t))
1310 if (def->getBlock() != &op.getRegion().front())
1313 if (def->getNumOperands() == 0) {
1314 if (
auto indexOp = dyn_cast<linalg::IndexOp>(def))
1319 if (def->getNumOperands() == 1) {
1320 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1321 if (x.has_value()) {
1323 if (isa<math::AbsFOp>(def))
1325 if (isa<complex::AbsOp>(def))
1327 if (isa<math::AbsIOp>(def))
1329 if (isa<math::CeilOp>(def))
1331 if (isa<math::FloorOp>(def))
1333 if (isa<math::SqrtOp>(def))
1335 if (isa<complex::SqrtOp>(def))
1337 if (isa<math::ExpM1Op>(def))
1339 if (isa<complex::Expm1Op>(def))
1341 if (isa<math::Log1pOp>(def))
1343 if (isa<complex::Log1pOp>(def))
1345 if (isa<math::SinOp>(def))
1347 if (isa<complex::SinOp>(def))
1349 if (isa<math::TanhOp>(def))
1351 if (isa<complex::TanhOp>(def))
1353 if (isa<arith::NegFOp>(def))
1355 if (isa<complex::NegOp>(def))
1357 if (isa<arith::TruncFOp>(def))
1359 if (isa<arith::ExtFOp>(def))
1361 if (isa<arith::FPToSIOp>(def))
1363 if (isa<arith::FPToUIOp>(def))
1365 if (isa<arith::SIToFPOp>(def))
1367 if (isa<arith::UIToFPOp>(def))
1369 if (isa<arith::ExtSIOp>(def))
1371 if (isa<arith::ExtUIOp>(def))
1373 if (isa<arith::IndexCastOp>(def))
1375 if (isa<arith::TruncIOp>(def))
1377 if (isa<complex::ImOp>(def))
1379 if (isa<complex::ReOp>(def))
1381 if (isa<arith::BitcastOp>(def))
1383 if (
auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1388 if (
auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1398 if (def->getNumOperands() == 2) {
1399 const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
1400 const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
1404 bool conjSpVals = xSpVals || ySpVals;
1405 bool disjSpVals = xSpVals && ySpVals;
1406 if (x.has_value() && y.has_value()) {
1409 if (isa<arith::MulFOp>(def))
1411 if (isa<complex::MulOp>(def))
1413 if (isa<arith::MulIOp>(def))
1415 if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1417 if (isa<complex::DivOp>(def) && !maybeZero(e1))
1419 if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1421 if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1423 if (isa<arith::AddFOp>(def))
1425 if (isa<complex::AddOp>(def))
1427 if (isa<arith::AddIOp>(def))
1429 if (isa<arith::SubFOp>(def))
1431 if (isa<complex::SubOp>(def))
1433 if (isa<arith::SubIOp>(def))
1435 if (isa<arith::AndIOp>(def))
1437 if (isa<arith::OrIOp>(def))
1439 if (isa<arith::XOrIOp>(def))
1441 if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1443 if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1445 if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1447 if (
auto ci = dyn_cast<arith::CmpIOp>(def)) {
1448 if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1449 ci.getPredicate() == arith::CmpIPredicate::sle &&
1450 ci.getPredicate() == arith::CmpIPredicate::sge &&
1451 ci.getPredicate() == arith::CmpIPredicate::ule &&
1452 ci.getPredicate() == arith::CmpIPredicate::uge) {
1455 return {std::nullopt,
false};
1459 ci.getPredicateAttr());
1460 return {e, conjSpVals};
1462 if (
auto cf = dyn_cast<arith::CmpFOp>(def)) {
1463 if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1464 cf.getPredicate() == arith::CmpFPredicate::OGE &&
1465 cf.getPredicate() == arith::CmpFPredicate::OLE &&
1466 cf.getPredicate() == arith::CmpFPredicate::ONE &&
1467 cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1468 cf.getPredicate() == arith::CmpFPredicate::UGE &&
1469 cf.getPredicate() == arith::CmpFPredicate::ULE &&
1470 cf.getPredicate() == arith::CmpFPredicate::ORD &&
1471 cf.getPredicate() == arith::CmpFPredicate::UNO) {
1474 return {std::nullopt,
false};
1477 cf.getPredicateAttr());
1478 return {e, conjSpVals};
1480 if (
auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1482 (binop.getLeftIdentity() ||
1484 (binop.getRightIdentity() ||
1492 if (def->getNumOperands() == 3) {
1493 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1494 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1495 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1496 bool hasSpDep = xDepSp || yDepSp || zDepSp;
1497 if (x.has_value() && y.has_value() && z.has_value()) {
1500 if (
auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1504 if (
auto selop = dyn_cast<arith::SelectOp>(def)) {
1508 const auto &cnd =
exp(*x);
1513 const auto &a =
exp(cnd.children.e0);
1514 const auto &b =
exp(cnd.children.e1);
1531 if (def->getNumResults() != 1)
1532 return {std::nullopt,
false};
1533 SmallVector<std::pair<std::optional<ExprId>,
bool>, 2> subExp;
1535 for (Value operand : def->getOperands())
1536 subExp.push_back(buildTensorExp(op, operand));
1538 if (llvm::all_of(subExp,
1539 [](
auto e) {
return e.first.has_value() && !e.second; })) {
1541 if (subExp.size() == 2) {
1543 *subExp[1].first, def);
1546 if (subExp.size() == 1) {
1554 return {std::nullopt,
false};
1564 YieldOp clonedYield = cast<YieldOp>(clonedBlock.
getTerminator());
1568 Value val = clonedYield.getSingleResult();
1569 rewriter.
eraseOp(clonedYield);
1570 rewriter.
eraseOp(placeholder);
1579 UnaryOp unop = cast<UnaryOp>(op);
1580 Region &presentRegion = unop.getPresentRegion();
1581 if (presentRegion.
empty())
1593 BinaryOp binop = cast<BinaryOp>(op);
1594 Region &overlapRegion = binop.getOverlapRegion();
1595 if (overlapRegion.
empty())
1599 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1608 if (isa<FloatType>(tp)) {
1609 auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
1610 cmp = rewriter.
create<arith::CmpFOp>(loc, pred, v0, zero);
1612 auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
1613 cmp = rewriter.
create<arith::CmpIOp>(loc, pred, v0, zero);
1615 return rewriter.
create<arith::SelectOp>(loc, cmp, v0, zero);
1620 const auto &expr =
exp(e);
1621 switch (expr.kind) {
1627 llvm_unreachable(
"unexpected non-op");
1630 return rewriter.
create<math::AbsFOp>(loc, v0);
1632 auto type = cast<ComplexType>(v0.
getType());
1633 auto eltType = cast<FloatType>(type.getElementType());
1634 return rewriter.
create<complex::AbsOp>(loc, eltType, v0);
1637 return rewriter.
create<math::AbsIOp>(loc, v0);
1639 return rewriter.
create<math::CeilOp>(loc, v0);
1641 return rewriter.
create<math::FloorOp>(loc, v0);
1643 return rewriter.
create<math::SqrtOp>(loc, v0);
1645 return rewriter.
create<complex::SqrtOp>(loc, v0);
1647 return rewriter.
create<math::ExpM1Op>(loc, v0);
1649 return rewriter.
create<complex::Expm1Op>(loc, v0);
1651 return rewriter.
create<math::Log1pOp>(loc, v0);
1653 return rewriter.
create<complex::Log1pOp>(loc, v0);
1655 return buildRelu(rewriter, loc, v0, expr.attr);
1657 return rewriter.
create<math::SinOp>(loc, v0);
1659 return rewriter.
create<complex::SinOp>(loc, v0);
1661 return rewriter.
create<math::TanhOp>(loc, v0);
1663 return rewriter.
create<complex::TanhOp>(loc, v0);
1665 return rewriter.
create<arith::NegFOp>(loc, v0);
1667 return rewriter.
create<complex::NegOp>(loc, v0);
1669 return rewriter.
create<arith::SubIOp>(
1675 return rewriter.
create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1677 return rewriter.
create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1679 return rewriter.
create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1681 return rewriter.
create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1683 return rewriter.
create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1685 return rewriter.
create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1687 return rewriter.
create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1689 return rewriter.
create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1691 return rewriter.
create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1693 return rewriter.
create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1695 auto type = cast<ComplexType>(v0.
getType());
1696 auto eltType = cast<FloatType>(type.getElementType());
1697 return rewriter.
create<complex::ImOp>(loc, eltType, v0);
1700 auto type = cast<ComplexType>(v0.
getType());
1701 auto eltType = cast<FloatType>(type.getElementType());
1702 return rewriter.
create<complex::ReOp>(loc, eltType, v0);
1705 return rewriter.
create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1708 return rewriter.
create<arith::MulFOp>(loc, v0, v1);
1710 return rewriter.
create<complex::MulOp>(loc, v0, v1);
1712 return rewriter.
create<arith::MulIOp>(loc, v0, v1);
1714 return rewriter.
create<arith::DivFOp>(loc, v0, v1);
1716 return rewriter.
create<complex::DivOp>(loc, v0, v1);
1718 return rewriter.
create<arith::DivSIOp>(loc, v0, v1);
1720 return rewriter.
create<arith::DivUIOp>(loc, v0, v1);
1722 return rewriter.
create<arith::AddFOp>(loc, v0, v1);
1724 return rewriter.
create<complex::AddOp>(loc, v0, v1);
1726 return rewriter.
create<arith::AddIOp>(loc, v0, v1);
1728 return rewriter.
create<arith::SubFOp>(loc, v0, v1);
1730 return rewriter.
create<complex::SubOp>(loc, v0, v1);
1732 return rewriter.
create<arith::SubIOp>(loc, v0, v1);
1734 return rewriter.
create<arith::AndIOp>(loc, v0, v1);
1736 return rewriter.
create<arith::OrIOp>(loc, v0, v1);
1738 return rewriter.
create<arith::XOrIOp>(loc, v0, v1);
1740 return rewriter.
create<arith::ShRSIOp>(loc, v0, v1);
1742 return rewriter.
create<arith::ShRUIOp>(loc, v0, v1);
1744 return rewriter.
create<arith::ShLIOp>(loc, v0, v1);
1746 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1747 return rewriter.
create<arith::CmpIOp>(loc, predicate, v0, v1);
1750 auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1751 return rewriter.
create<arith::CmpFOp>(loc, predicate, v0, v1);
1760 cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
1765 ReduceOp redOp = cast<ReduceOp>(expr.op);
1766 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1777 llvm_unreachable(
"unexpected expression kind in build");
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
TypedAttr getZeroAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
Specialization of arith.constant op that returns an integer of index type.
Specialization of arith.constant op that returns an integer value.
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operand set to zero,...
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level contains non-trivial index expressio...
void dumpBits(const BitVector &bits) const
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level contains non-trivial index expression.
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets and also set one of the operand to zero: (s0 /\_op s1 (e0 op e1...
void dumpSet(LatSetId s) const
void dumpLat(LatPointId p) const
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr, Attribute attr=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
ArrayRef< LatPointId > set(LatSetId s) const
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
void dumpExp(ExprId e) const
Print methods (for debugging).
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
const LatPoint & lat(LatPointId p) const
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
@ Type
An inlay hint that for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
LevelFormat
This enum defines all supported storage format without the level properties.
static bool isCertainZero(Value val)
Only returns true if we are certain this is a zero.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v)
Ensures that the sparsifier can generate code for expression.
unsigned LatSetId
LatSet identifiers.
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, ValueRange vals)
std::string toMLIRString(LevelType lt)
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0)
uint64_t Level
The type of level identifiers and level-ranks.
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1)
static bool isAdmissibleBranch(Operation *op, Region ®ion)
Ensures that the sparsifier can generate code for branch.
unsigned LoopId
Loop identifiers.
static const char * kindToOpSymbol(TensorExp::Kind kind)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
static bool isGreater(TensorExp::Kind kind, Attribute attr)
unsigned ExprId
TensorExp identifiers.
static ExpArity getExpArity(TensorExp::Kind k)
static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0, Attribute attr)
unsigned LatPointId
LatPoint identifiers.
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Kind
Tensor expression kind.
Children children
All other expressions hold the ExprIds of their children.
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.