16 #include "llvm/Support/Debug.h"
20 namespace sparse_tensor {
98 llvm_unreachable(
"unexpected kind");
107 : kind(k), val(v), op(o) {
220 llvm_unreachable(
"unexpected kind");
225 : outTensor(numInputOutputTensors - 1),
226 syntheticTensor(numInputOutputTensors),
227 numTensors(numInputOutputTensors + 1), numLoops(numLoops),
231 loopToLvl(numTensors,
232 std::vector<std::optional<
Level>>(numLoops, std::nullopt)),
233 lvlToLoop(numTensors,
234 std::vector<std::optional<
LoopId>>(maxLvlRank, std::nullopt)),
235 loopToUnresolvedLvls(numLoops, std::vector<std::optional<
LvlLTPair>>(
236 numTensors, std::nullopt)),
237 levelToDependentLoop(numTensors,
240 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
247 assert(isValidTensorId(t));
248 const ExprId eNew(tensorExps.size());
250 Value(),
nullptr,
nullptr);
255 assert(isValidLoopId(i));
256 const ExprId eNew(tensorExps.size());
258 Value(),
nullptr,
nullptr);
263 const ExprId eNew(tensorExps.size());
270 const ExprId eNew(tensorExps.size());
279 const ExprId eNew(tensorExps.size());
280 tensorExps.emplace_back(k, e0, e1,
Value(), op, attr);
287 const ExprId eNew(tensorExps.size());
294 const unsigned size = numLoops * numTensors;
296 latPoints.emplace_back(size, e);
297 latPoints[pNew].bits.set(b);
302 assert(bits.size() == numLoops * numTensors);
304 latPoints.emplace_back(bits, e);
309 const LatSetId sNew(latSets.size());
310 latSets.emplace_back();
319 const auto &point0 =
lat(p0);
320 const auto &point1 =
lat(p1);
321 BitVector bits(point0.bits);
323 const ExprId ne =
addExp(kind, point0.exp, point1.exp, op, attr);
324 latPoints.emplace_back(bits, ne);
330 auto &setNew = latSets[sNew];
333 setNew.push_back(
conjLat(e, p0, p1, op));
342 latSets[sNew].append(latSets[s0]);
352 latSets[sNew].append(latSets[s1]);
366 assert(
exp(e0).kind !=
exp(e1).kind);
375 latSets[sNew].append(latSets[lhsSet]);
376 latSets[sNew].append(latSets[rhsSet]);
389 latSets[sNew].append(latSets[s0]);
395 latSets[sNew].append(latSets[s1]);
405 auto &setNew = latSets[sNew];
407 const auto &point = latPoints[p];
408 setNew.push_back(
addLat(point.bits,
addExp(kind, point.exp, v, op)));
419 auto &setNew = latSets[sNew];
422 const auto &point = latPoints[p];
423 ExprId newExp = lhsZero ?
addExp(kind, zeroExp, point.exp,
nullptr, a)
424 :
addExp(kind, point.exp, zeroExp,
nullptr, a);
425 setNew.push_back(
addLat(point.bits, newExp));
432 auto &setNew = latSets[sNew];
433 const auto &set0 =
set(s0);
434 assert(!set0.empty());
444 assert(!
latGT(p1, p2));
450 assert(!add ||
latGT(p0, p1));
453 setNew.push_back(p1);
463 bool isSingleton =
true;
465 if (p0 != p1 &&
latGT(p0, p1)) {
471 BitVector simple(latPoints[p0].bits);
478 for (
unsigned b = 0; b < be; b++) {
487 for (
unsigned b = be - 1 - offset, i = 0; i < be;
488 b = b == 0 ? be - 1 : b - 1, i++) {
492 if (!lt.hasSparseSemantic()) {
503 const BitVector &bitsi =
lat(i).
bits;
504 const BitVector &bitsj =
lat(
j).
bits;
505 assert(bitsi.size() == bitsj.size());
506 if (bitsi.count() > bitsj.count()) {
507 for (
TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
508 if (bitsj[b] && !bitsi[b])
516 BitVector tmp(latPoints[
j].bits);
517 tmp ^= latPoints[i].bits;
522 const auto &expr =
exp(e);
525 return expr.tensor == t;
531 const ExprId e0 = expr.children.e0;
535 const ExprId e0 = expr.children.e0;
536 const ExprId e1 = expr.children.e1;
540 llvm_unreachable(
"unexpected arity");
544 const auto &expr =
exp(e);
573 llvm_unreachable(
"unexpected kind");
577 assert(isValidTensorId(t));
578 const auto &expr =
exp(e);
582 return expr.tensor == t;
629 assert(!maybeZero(expr.children.e1));
634 assert(isInvariant(expr.children.e1));
643 isInvariant(expr.children.e1);
645 return isInvariant(expr.children.e0);
666 llvm_unreachable(
"unexpected kind");
672 if (lt.hasSparseSemantic())
747 return "binary_branch";
792 llvm_unreachable(
"unexpected kind for symbol");
796 const auto &expr =
exp(e);
800 if (expr.tensor == syntheticTensor)
801 llvm::dbgs() <<
"synthetic_";
802 else if (expr.tensor == outTensor)
803 llvm::dbgs() <<
"output_";
804 llvm::dbgs() <<
"tensor_" << expr.tensor;
807 llvm::dbgs() <<
"invariant";
813 llvm::dbgs() <<
"loopvar_" << expr.loop;
882 llvm::dbgs() <<
"{" << expr.attr <<
"}";
895 const auto &point =
lat(p);
896 llvm::dbgs() <<
"lat(";
898 llvm::dbgs() <<
" :";
900 llvm::dbgs() <<
" : ";
902 llvm::dbgs() <<
" )\n";
906 const auto &ss =
set(s);
907 llvm::dbgs() <<
"{ #" << ss.size() <<
"\n";
912 llvm::dbgs() <<
"}\n";
916 for (
TensorLoopId b = 0, be = bits.size(); b < be; b++) {
920 const auto lt = lvlTypes[t][i];
922 llvm::dbgs() <<
" DEP_" << t <<
"_" << i;
924 llvm::dbgs() <<
" i_" << t <<
"_" << i <<
"_" <<
toMLIRString(lt);
940 const auto &expr =
exp(e);
957 if (hasSparseOut && t == outTensor)
960 latSets[s].push_back(
addLat(t, i, e));
1002 const ExprId e0 = expr.children.e0;
1003 const Value v = expr.val;
1011 const ExprId e0 = expr.children.e0;
1022 const ExprId e0 = expr.children.e0;
1023 UnaryOp unop = cast<UnaryOp>(expr.op);
1025 Region &absentRegion = unop.getAbsentRegion();
1026 if (absentRegion.
empty()) {
1033 YieldOp absentYield = cast<YieldOp>(absentBlock.
getTerminator());
1034 const Value absentVal = absentYield.getSingleResult();
1053 const ExprId e0 = expr.children.e0;
1054 const ExprId e1 = expr.children.e1;
1075 const ExprId e0 = expr.children.e0;
1076 const ExprId e1 = expr.children.e1;
1077 assert(!maybeZero(e1));
1096 const ExprId e0 = expr.children.e0;
1097 const ExprId e1 = expr.children.e1;
1110 const ExprId e0 = expr.children.e0;
1111 const ExprId e1 = expr.children.e1;
1121 const ExprId e0 = expr.children.e0;
1122 const ExprId e1 = expr.children.e1;
1123 assert(isInvariant(e1));
1134 const ExprId e0 = expr.children.e0;
1135 const ExprId e1 = expr.children.e1;
1136 BinaryOp binop = cast<BinaryOp>(expr.op);
1139 Region &leftRegion = binop.getLeftRegion();
1140 Region &rightRegion = binop.getRightRegion();
1143 if (!leftRegion.
empty()) {
1149 if (!rightRegion.
empty()) {
1153 bool includeLeft = binop.getLeftIdentity() || !leftRegion.
empty();
1154 bool includeRight = binop.getRightIdentity() || !rightRegion.
empty();
1155 return combiSet(e, child0, child1, binop, includeLeft,
1162 const ExprId e0 = expr.children.e0;
1163 const ExprId e1 = expr.children.e1;
1172 const ExprId e0 = expr.children.e0;
1177 const ExprId e0 = expr.children.e0;
1178 const ExprId e1 = expr.children.e1;
1183 llvm_unreachable(
"unexpected expression kind");
1188 Operation *yield = op.getRegion().front().getTerminator();
1189 assert(isa<linalg::YieldOp>(yield));
1190 return buildTensorExp(op, yield->
getOperand(0)).first;
1194 bool Merger::maybeZero(
ExprId e)
const {
1195 const auto &expr =
exp(e);
1197 if (
auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1198 ArrayAttr arrayAttr = c.getValue();
1199 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1200 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1203 return c.value() == 0;
1205 return c.value().isZero();
1210 Type Merger::inferType(
ExprId e, Value src)
const {
1215 if (
auto vtp = dyn_cast<VectorType>(src.getType()))
1216 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1223 if (isa<BlockArgument>(v))
1227 if (isa<linalg::IndexOp>(def))
1231 return def->
getBlock() != op->getBlock();
1246 assert(isa<YieldOp>(yield));
1250 std::pair<std::optional<ExprId>,
bool>
1251 Merger::buildTensorExp(linalg::GenericOp op,
Value v) {
1253 if (
auto arg = dyn_cast<BlockArgument>(v)) {
1258 if (arg.getOwner()->getParentOp() == op) {
1261 if (!op.isScalar(&t))
1271 if (def->getBlock() != &op.getRegion().front())
1274 if (def->getNumOperands() == 0) {
1275 if (
auto indexOp = dyn_cast<linalg::IndexOp>(def))
1280 if (def->getNumOperands() == 1) {
1281 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1282 if (x.has_value()) {
1284 if (isa<math::AbsFOp>(def))
1286 if (isa<complex::AbsOp>(def))
1288 if (isa<math::AbsIOp>(def))
1290 if (isa<math::CeilOp>(def))
1292 if (isa<math::FloorOp>(def))
1294 if (isa<math::SqrtOp>(def))
1296 if (isa<complex::SqrtOp>(def))
1298 if (isa<math::ExpM1Op>(def))
1300 if (isa<complex::Expm1Op>(def))
1302 if (isa<math::Log1pOp>(def))
1304 if (isa<complex::Log1pOp>(def))
1306 if (isa<math::SinOp>(def))
1308 if (isa<complex::SinOp>(def))
1310 if (isa<math::TanhOp>(def))
1312 if (isa<complex::TanhOp>(def))
1314 if (isa<arith::NegFOp>(def))
1316 if (isa<complex::NegOp>(def))
1318 if (isa<arith::TruncFOp>(def))
1320 if (isa<arith::ExtFOp>(def))
1322 if (isa<arith::FPToSIOp>(def))
1324 if (isa<arith::FPToUIOp>(def))
1326 if (isa<arith::SIToFPOp>(def))
1328 if (isa<arith::UIToFPOp>(def))
1330 if (isa<arith::ExtSIOp>(def))
1332 if (isa<arith::ExtUIOp>(def))
1334 if (isa<arith::IndexCastOp>(def))
1336 if (isa<arith::TruncIOp>(def))
1338 if (isa<complex::ImOp>(def))
1340 if (isa<complex::ReOp>(def))
1342 if (isa<arith::BitcastOp>(def))
1344 if (
auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1349 if (
auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1358 if (def->getNumOperands() == 2) {
1359 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1360 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1361 bool hasSpDep = xDepSp || yDepSp;
1362 if (x.has_value() && y.has_value()) {
1365 if (isa<arith::MulFOp>(def))
1367 if (isa<complex::MulOp>(def))
1369 if (isa<arith::MulIOp>(def))
1371 if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1373 if (isa<complex::DivOp>(def) && !maybeZero(e1))
1375 if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1377 if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1379 if (isa<arith::AddFOp>(def))
1381 if (isa<complex::AddOp>(def))
1383 if (isa<arith::AddIOp>(def))
1385 if (isa<arith::SubFOp>(def))
1387 if (isa<complex::SubOp>(def))
1389 if (isa<arith::SubIOp>(def))
1391 if (isa<arith::AndIOp>(def))
1393 if (isa<arith::OrIOp>(def))
1395 if (isa<arith::XOrIOp>(def))
1397 if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1399 if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1401 if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1403 if (
auto ci = dyn_cast<arith::CmpIOp>(def)) {
// Bail out when the integer predicate is ANY ONE of the kinds listed below.
// The tests must be OR-ed: one predicate value can never equal several
// distinct enumerators at once, so an AND chain is always false and would
// leave the early return underneath unreachable.
1404 if (ci.getPredicate() == arith::CmpIPredicate::eq ||
1405 ci.getPredicate() == arith::CmpIPredicate::sle ||
1406 ci.getPredicate() == arith::CmpIPredicate::sge ||
1407 ci.getPredicate() == arith::CmpIPredicate::ule ||
1408 ci.getPredicate() == arith::CmpIPredicate::uge) {
1411 return {std::nullopt,
false};
1415 ci.getPredicateAttr());
1416 return {e, hasSpDep};
1418 if (
auto cf = dyn_cast<arith::CmpFOp>(def)) {
// Same screening for floating-point comparisons: reject the expression when
// the predicate matches ANY ONE of the listed kinds (OR, not AND, for the
// reason given above).
1419 if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
1420 cf.getPredicate() == arith::CmpFPredicate::OGE ||
1421 cf.getPredicate() == arith::CmpFPredicate::OLE ||
1422 cf.getPredicate() == arith::CmpFPredicate::ONE ||
1423 cf.getPredicate() == arith::CmpFPredicate::UEQ ||
1424 cf.getPredicate() == arith::CmpFPredicate::UGE ||
1425 cf.getPredicate() == arith::CmpFPredicate::ULE ||
1426 cf.getPredicate() == arith::CmpFPredicate::ORD ||
1427 cf.getPredicate() == arith::CmpFPredicate::UNO) {
1430 return {std::nullopt,
false};
1433 cf.getPredicateAttr());
1434 return {e, hasSpDep};
1436 if (
auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1438 (binop.getLeftIdentity() ||
1440 (binop.getRightIdentity() ||
1447 if (def->getNumOperands() == 3) {
1448 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1449 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1450 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1451 bool hasSpDep = xDepSp || yDepSp || zDepSp;
1452 if (x.has_value() && y.has_value() && z.has_value()) {
1455 if (
auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1466 if (def->getNumResults() != 1)
1467 return {std::nullopt,
false};
1469 SmallVector<std::pair<std::optional<ExprId>,
bool>, 2> subExp;
1471 for (Value operand : def->getOperands())
1472 subExp.push_back(buildTensorExp(op, operand));
1474 if (llvm::all_of(subExp,
1475 [](
auto e) {
return e.first.has_value() && !e.second; })) {
1477 if (subExp.size() == 2) {
1479 *subExp[1].first, def);
1482 if (subExp.size() == 1) {
1489 return {std::nullopt,
false};
1499 YieldOp clonedYield = cast<YieldOp>(clonedBlock.
getTerminator());
1503 Value val = clonedYield.getSingleResult();
1504 rewriter.
eraseOp(clonedYield);
1505 rewriter.
eraseOp(placeholder);
1514 UnaryOp unop = cast<UnaryOp>(op);
1515 Region &presentRegion = unop.getPresentRegion();
1516 if (presentRegion.
empty())
1528 BinaryOp binop = cast<BinaryOp>(op);
1529 Region &overlapRegion = binop.getOverlapRegion();
1530 if (overlapRegion.
empty())
1534 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1539 const auto &expr =
exp(e);
1540 switch (expr.kind) {
1546 llvm_unreachable(
"unexpected non-op");
1549 return rewriter.
create<math::AbsFOp>(loc, v0);
1551 auto type = cast<ComplexType>(v0.
getType());
1552 auto eltType = cast<FloatType>(type.getElementType());
1553 return rewriter.
create<complex::AbsOp>(loc, eltType, v0);
1556 return rewriter.
create<math::AbsIOp>(loc, v0);
1558 return rewriter.
create<math::CeilOp>(loc, v0);
1560 return rewriter.
create<math::FloorOp>(loc, v0);
1562 return rewriter.
create<math::SqrtOp>(loc, v0);
1564 return rewriter.
create<complex::SqrtOp>(loc, v0);
1566 return rewriter.
create<math::ExpM1Op>(loc, v0);
1568 return rewriter.
create<complex::Expm1Op>(loc, v0);
1570 return rewriter.
create<math::Log1pOp>(loc, v0);
1572 return rewriter.
create<complex::Log1pOp>(loc, v0);
1574 return rewriter.
create<math::SinOp>(loc, v0);
1576 return rewriter.
create<complex::SinOp>(loc, v0);
1578 return rewriter.
create<math::TanhOp>(loc, v0);
1580 return rewriter.
create<complex::TanhOp>(loc, v0);
1582 return rewriter.
create<arith::NegFOp>(loc, v0);
1584 return rewriter.
create<complex::NegOp>(loc, v0);
1586 return rewriter.
create<arith::SubIOp>(
1592 return rewriter.
create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1594 return rewriter.
create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1596 return rewriter.
create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1598 return rewriter.
create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1600 return rewriter.
create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1602 return rewriter.
create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1604 return rewriter.
create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1606 return rewriter.
create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1608 return rewriter.
create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1610 return rewriter.
create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1612 auto type = cast<ComplexType>(v0.
getType());
1613 auto eltType = cast<FloatType>(type.getElementType());
1614 return rewriter.
create<complex::ImOp>(loc, eltType, v0);
1617 auto type = cast<ComplexType>(v0.
getType());
1618 auto eltType = cast<FloatType>(type.getElementType());
1619 return rewriter.
create<complex::ReOp>(loc, eltType, v0);
1622 return rewriter.
create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1625 return rewriter.
create<arith::MulFOp>(loc, v0, v1);
1627 return rewriter.
create<complex::MulOp>(loc, v0, v1);
1629 return rewriter.
create<arith::MulIOp>(loc, v0, v1);
1631 return rewriter.
create<arith::DivFOp>(loc, v0, v1);
1633 return rewriter.
create<complex::DivOp>(loc, v0, v1);
1635 return rewriter.
create<arith::DivSIOp>(loc, v0, v1);
1637 return rewriter.
create<arith::DivUIOp>(loc, v0, v1);
1639 return rewriter.
create<arith::AddFOp>(loc, v0, v1);
1641 return rewriter.
create<complex::AddOp>(loc, v0, v1);
1643 return rewriter.
create<arith::AddIOp>(loc, v0, v1);
1645 return rewriter.
create<arith::SubFOp>(loc, v0, v1);
1647 return rewriter.
create<complex::SubOp>(loc, v0, v1);
1649 return rewriter.
create<arith::SubIOp>(loc, v0, v1);
1651 return rewriter.
create<arith::AndIOp>(loc, v0, v1);
1653 return rewriter.
create<arith::OrIOp>(loc, v0, v1);
1655 return rewriter.
create<arith::XOrIOp>(loc, v0, v1);
1657 return rewriter.
create<arith::ShRSIOp>(loc, v0, v1);
1659 return rewriter.
create<arith::ShRUIOp>(loc, v0, v1);
1661 return rewriter.
create<arith::ShLIOp>(loc, v0, v1);
1663 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1664 return rewriter.
create<arith::CmpIOp>(loc, predicate, v0, v1);
1667 auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1668 return rewriter.
create<arith::CmpFOp>(loc, predicate, v0, v1);
1676 return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1681 ReduceOp redOp = cast<ReduceOp>(expr.op);
1682 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1693 llvm_unreachable(
"unexpected expression kind in build");
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
TypedAttr getZeroAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
Specialization of arith.constant op that returns an integer of index type.
Specialization of arith.constant op that returns an integer value.
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operand set to zero,...
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level contains non-trivial index expressio...
void dumpBits(const BitVector &bits) const
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level that contains a non-trivial index expression.
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets and also set one of the operand to zero: (s0 /\_op s1 (e0 op e1...
void dumpSet(LatSetId s) const
void dumpLat(LatPointId p) const
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
ArrayRef< LatPointId > set(LatSetId s) const
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
void dumpExp(ExprId e) const
Print methods (for debugging).
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
const LatPoint & lat(LatPointId p) const
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
@ Type
An inlay hint that is for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
LevelFormat
This enum defines all supported storage format without the level properties.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v)
Ensures that the sparsifier can generate code for expression.
unsigned LatSetId
LatSet identifiers.
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region, ValueRange vals)
std::string toMLIRString(LevelType lt)
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0)
uint64_t Level
The type of level identifiers and level-ranks.
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1)
static bool isAdmissibleBranch(Operation *op, Region &region)
Ensures that the sparsifier can generate code for branch.
unsigned LoopId
Loop identifiers.
static const char * kindToOpSymbol(TensorExp::Kind kind)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
unsigned ExprId
TensorExp identifiers.
static ExpArity getExpArity(TensorExp::Kind k)
unsigned LatPointId
LatPoint identifiers.
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Kind
Tensor expression kind.
Children children
All other expressions hold the ExprIds of their children.
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.