16 #include "llvm/Support/Debug.h"
20 namespace sparse_tensor {
98 llvm_unreachable(
"unexpected kind");
107 : kind(k), val(v), op(o) {
220 llvm_unreachable(
"unexpected kind");
225 : outTensor(numInputOutputTensors - 1),
226 syntheticTensor(numInputOutputTensors),
227 numTensors(numInputOutputTensors + 1), numLoops(numLoops),
230 loopToLvl(numTensors,
231 std::vector<std::optional<
Level>>(numLoops, std::nullopt)),
232 lvlToLoop(numTensors,
233 std::vector<std::optional<
LoopId>>(maxLvlRank, std::nullopt)),
234 loopToUnresolvedLvls(numLoops, std::vector<std::optional<
LvlLTPair>>(
235 numTensors, std::nullopt)),
236 levelToDependentLoop(numTensors,
239 loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
246 assert(isValidTensorId(t));
247 const ExprId eNew(tensorExps.size());
249 Value(),
nullptr,
nullptr);
254 assert(isValidLoopId(i));
255 const ExprId eNew(tensorExps.size());
257 Value(),
nullptr,
nullptr);
262 const ExprId eNew(tensorExps.size());
269 const ExprId eNew(tensorExps.size());
278 const ExprId eNew(tensorExps.size());
279 tensorExps.emplace_back(k, e0, e1,
Value(), op, attr);
286 const ExprId eNew(tensorExps.size());
293 const unsigned size = numLoops * numTensors;
295 latPoints.emplace_back(size, e);
296 latPoints[pNew].bits.set(b);
301 assert(bits.size() == numLoops * numTensors);
303 latPoints.emplace_back(bits, e);
308 const LatSetId sNew(latSets.size());
309 latSets.emplace_back();
318 const auto &point0 =
lat(p0);
319 const auto &point1 =
lat(p1);
320 BitVector bits(point0.bits);
322 const ExprId ne =
addExp(kind, point0.exp, point1.exp, op, attr);
323 latPoints.emplace_back(bits, ne);
329 auto &setNew = latSets[sNew];
332 setNew.push_back(
conjLat(e, p0, p1, op));
341 latSets[sNew].append(latSets[s0]);
351 latSets[sNew].append(latSets[s1]);
365 assert(
exp(e0).kind !=
exp(e1).kind);
374 latSets[sNew].append(latSets[lhsSet]);
375 latSets[sNew].append(latSets[rhsSet]);
388 latSets[sNew].append(latSets[s0]);
394 latSets[sNew].append(latSets[s1]);
404 auto &setNew = latSets[sNew];
406 const auto &point = latPoints[p];
407 setNew.push_back(
addLat(point.bits,
addExp(kind, point.exp, v, op)));
418 auto &setNew = latSets[sNew];
421 const auto &point = latPoints[p];
422 ExprId newExp = lhsZero ?
addExp(kind, zeroExp, point.exp,
nullptr, a)
423 :
addExp(kind, point.exp, zeroExp,
nullptr, a);
424 setNew.push_back(
addLat(point.bits, newExp));
431 auto &setNew = latSets[sNew];
432 const auto &set0 =
set(s0);
433 assert(!set0.empty());
443 assert(!
latGT(p1, p2));
449 assert(!add ||
latGT(p0, p1));
452 setNew.push_back(p1);
462 bool isSingleton =
true;
464 if (p0 != p1 &&
latGT(p0, p1)) {
470 BitVector simple(latPoints[p0].bits);
477 for (
unsigned b = 0; b < be; b++) {
486 for (
unsigned b = be - 1 - offset, i = 0; i < be;
487 b = b == 0 ? be - 1 : b - 1, i++) {
503 const BitVector &bitsi =
lat(i).
bits;
504 const BitVector &bitsj =
lat(
j).
bits;
505 assert(bitsi.size() == bitsj.size());
506 if (bitsi.count() > bitsj.count()) {
507 for (
TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
508 if (bitsj[b] && !bitsi[b])
516 BitVector tmp(latPoints[
j].bits);
517 tmp ^= latPoints[i].bits;
522 const auto &expr =
exp(e);
525 return expr.tensor == t;
531 const ExprId e0 = expr.children.e0;
535 const ExprId e0 = expr.children.e0;
536 const ExprId e1 = expr.children.e1;
540 llvm_unreachable(
"unexpected arity");
544 const auto &expr =
exp(e);
573 llvm_unreachable(
"unexpected kind");
577 assert(isValidTensorId(t));
578 const auto &expr =
exp(e);
582 return expr.tensor == t;
629 assert(!maybeZero(expr.children.e1));
634 assert(isInvariant(expr.children.e1));
643 isInvariant(expr.children.e1);
645 return isInvariant(expr.children.e0);
666 llvm_unreachable(
"unexpected kind");
748 return "binary_branch";
793 llvm_unreachable(
"unexpected kind for symbol");
797 const auto &expr =
exp(e);
801 if (expr.tensor == syntheticTensor)
802 llvm::dbgs() <<
"synthetic_";
803 else if (expr.tensor == outTensor)
804 llvm::dbgs() <<
"output_";
805 llvm::dbgs() <<
"tensor_" << expr.tensor;
808 llvm::dbgs() <<
"invariant";
814 llvm::dbgs() <<
"loopvar_" << expr.loop;
883 llvm::dbgs() <<
"{" << expr.attr <<
"}";
896 const auto &point =
lat(p);
897 llvm::dbgs() <<
"lat(";
899 llvm::dbgs() <<
" :";
901 llvm::dbgs() <<
" : ";
903 llvm::dbgs() <<
" )\n";
907 const auto &ss =
set(s);
908 llvm::dbgs() <<
"{ #" << ss.size() <<
"\n";
913 llvm::dbgs() <<
"}\n";
917 for (
TensorLoopId b = 0, be = bits.size(); b < be; b++) {
921 const auto lt = lvlTypes[t][i];
923 llvm::dbgs() <<
" DEP_" << t <<
"_" << i;
925 llvm::dbgs() <<
" i_" << t <<
"_" << i <<
"_" <<
toMLIRString(lt);
941 const auto &expr =
exp(e);
958 if (hasSparseOut && t == outTensor)
961 latSets[s].push_back(
addLat(t, i, e));
1003 const ExprId e0 = expr.children.e0;
1004 const Value v = expr.val;
1012 const ExprId e0 = expr.children.e0;
1023 const ExprId e0 = expr.children.e0;
1024 UnaryOp unop = cast<UnaryOp>(expr.op);
1026 Region &absentRegion = unop.getAbsentRegion();
1027 if (absentRegion.
empty()) {
1034 YieldOp absentYield = cast<YieldOp>(absentBlock.
getTerminator());
1035 const Value absentVal = absentYield.getResult();
1054 const ExprId e0 = expr.children.e0;
1055 const ExprId e1 = expr.children.e1;
1076 const ExprId e0 = expr.children.e0;
1077 const ExprId e1 = expr.children.e1;
1078 assert(!maybeZero(e1));
1097 const ExprId e0 = expr.children.e0;
1098 const ExprId e1 = expr.children.e1;
1111 const ExprId e0 = expr.children.e0;
1112 const ExprId e1 = expr.children.e1;
1122 const ExprId e0 = expr.children.e0;
1123 const ExprId e1 = expr.children.e1;
1124 assert(isInvariant(e1));
1135 const ExprId e0 = expr.children.e0;
1136 const ExprId e1 = expr.children.e1;
1137 BinaryOp binop = cast<BinaryOp>(expr.op);
1140 Region &leftRegion = binop.getLeftRegion();
1141 Region &rightRegion = binop.getRightRegion();
1144 if (!leftRegion.
empty()) {
1150 if (!rightRegion.
empty()) {
1154 bool includeLeft = binop.getLeftIdentity() || !leftRegion.
empty();
1155 bool includeRight = binop.getRightIdentity() || !rightRegion.
empty();
1156 return combiSet(e, child0, child1, binop, includeLeft,
1163 const ExprId e0 = expr.children.e0;
1164 const ExprId e1 = expr.children.e1;
1173 const ExprId e0 = expr.children.e0;
1178 const ExprId e0 = expr.children.e0;
1179 const ExprId e1 = expr.children.e1;
1184 llvm_unreachable(
"unexpected expression kind");
1189 Operation *yield = op.getRegion().front().getTerminator();
1190 assert(isa<linalg::YieldOp>(yield));
1191 return buildTensorExp(op, yield->
getOperand(0)).first;
1195 bool Merger::maybeZero(
ExprId e)
const {
1196 const auto &expr =
exp(e);
1198 if (
auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1199 ArrayAttr arrayAttr = c.getValue();
1200 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1201 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1204 return c.value() == 0;
1206 return c.value().isZero();
1211 Type Merger::inferType(
ExprId e, Value src)
const {
1216 if (
auto vtp = dyn_cast<VectorType>(src.getType()))
1217 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1224 if (isa<BlockArgument>(v))
1228 if (isa<linalg::IndexOp>(def))
1232 return def->
getBlock() != op->getBlock();
1247 assert(isa<YieldOp>(yield));
1251 std::pair<std::optional<ExprId>,
bool>
1252 Merger::buildTensorExp(linalg::GenericOp op,
Value v) {
1254 if (
auto arg = dyn_cast<BlockArgument>(v)) {
1259 if (arg.getOwner()->getParentOp() == op) {
1262 if (!op.isScalar(&t))
1272 if (def->getBlock() != &op.getRegion().front())
1275 if (def->getNumOperands() == 0) {
1276 if (
auto indexOp = dyn_cast<linalg::IndexOp>(def))
1281 if (def->getNumOperands() == 1) {
1282 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1283 if (x.has_value()) {
1285 if (isa<math::AbsFOp>(def))
1287 if (isa<complex::AbsOp>(def))
1289 if (isa<math::AbsIOp>(def))
1291 if (isa<math::CeilOp>(def))
1293 if (isa<math::FloorOp>(def))
1295 if (isa<math::SqrtOp>(def))
1297 if (isa<complex::SqrtOp>(def))
1299 if (isa<math::ExpM1Op>(def))
1301 if (isa<complex::Expm1Op>(def))
1303 if (isa<math::Log1pOp>(def))
1305 if (isa<complex::Log1pOp>(def))
1307 if (isa<math::SinOp>(def))
1309 if (isa<complex::SinOp>(def))
1311 if (isa<math::TanhOp>(def))
1313 if (isa<complex::TanhOp>(def))
1315 if (isa<arith::NegFOp>(def))
1317 if (isa<complex::NegOp>(def))
1319 if (isa<arith::TruncFOp>(def))
1321 if (isa<arith::ExtFOp>(def))
1323 if (isa<arith::FPToSIOp>(def))
1325 if (isa<arith::FPToUIOp>(def))
1327 if (isa<arith::SIToFPOp>(def))
1329 if (isa<arith::UIToFPOp>(def))
1331 if (isa<arith::ExtSIOp>(def))
1333 if (isa<arith::ExtUIOp>(def))
1335 if (isa<arith::IndexCastOp>(def))
1337 if (isa<arith::TruncIOp>(def))
1339 if (isa<complex::ImOp>(def))
1341 if (isa<complex::ReOp>(def))
1343 if (isa<arith::BitcastOp>(def))
1345 if (
auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
1350 if (
auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
1359 if (def->getNumOperands() == 2) {
1360 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1361 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1362 bool hasSpDep = xDepSp || yDepSp;
1363 if (x.has_value() && y.has_value()) {
1366 if (isa<arith::MulFOp>(def))
1368 if (isa<complex::MulOp>(def))
1370 if (isa<arith::MulIOp>(def))
1372 if (isa<arith::DivFOp>(def) && !maybeZero(e1))
1374 if (isa<complex::DivOp>(def) && !maybeZero(e1))
1376 if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
1378 if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
1380 if (isa<arith::AddFOp>(def))
1382 if (isa<complex::AddOp>(def))
1384 if (isa<arith::AddIOp>(def))
1386 if (isa<arith::SubFOp>(def))
1388 if (isa<complex::SubOp>(def))
1390 if (isa<arith::SubIOp>(def))
1392 if (isa<arith::AndIOp>(def))
1394 if (isa<arith::OrIOp>(def))
1396 if (isa<arith::XOrIOp>(def))
1398 if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1400 if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1402 if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1404 if (
auto ci = dyn_cast<arith::CmpIOp>(def)) {
1405 if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1406 ci.getPredicate() == arith::CmpIPredicate::sle &&
1407 ci.getPredicate() == arith::CmpIPredicate::sge &&
1408 ci.getPredicate() == arith::CmpIPredicate::ule &&
1409 ci.getPredicate() == arith::CmpIPredicate::uge) {
1412 return {std::nullopt,
false};
1416 ci.getPredicateAttr());
1417 return {e, hasSpDep};
1419 if (
auto cf = dyn_cast<arith::CmpFOp>(def)) {
1420 if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1421 cf.getPredicate() == arith::CmpFPredicate::OGE &&
1422 cf.getPredicate() == arith::CmpFPredicate::OLE &&
1423 cf.getPredicate() == arith::CmpFPredicate::ONE &&
1424 cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1425 cf.getPredicate() == arith::CmpFPredicate::UGE &&
1426 cf.getPredicate() == arith::CmpFPredicate::ULE &&
1427 cf.getPredicate() == arith::CmpFPredicate::ORD &&
1428 cf.getPredicate() == arith::CmpFPredicate::UNO) {
1431 return {std::nullopt,
false};
1434 cf.getPredicateAttr());
1435 return {e, hasSpDep};
1437 if (
auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1439 (binop.getLeftIdentity() ||
1441 (binop.getRightIdentity() ||
1448 if (def->getNumOperands() == 3) {
1449 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1450 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1451 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1452 bool hasSpDep = xDepSp || yDepSp || zDepSp;
1453 if (x.has_value() && y.has_value() && z.has_value()) {
1456 if (
auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1467 if (def->getNumResults() != 1)
1468 return {std::nullopt,
false};
1470 SmallVector<std::pair<std::optional<ExprId>,
bool>, 2> subExp;
1472 for (Value operand : def->getOperands())
1473 subExp.push_back(buildTensorExp(op, operand));
1475 if (llvm::all_of(subExp,
1476 [](
auto e) {
return e.first.has_value() && !e.second; })) {
1478 if (subExp.size() == 2) {
1480 *subExp[1].first, def);
1483 if (subExp.size() == 1) {
1490 return {std::nullopt,
false};
1500 YieldOp clonedYield = cast<YieldOp>(clonedBlock.
getTerminator());
1504 Value val = clonedYield.getResult();
1505 rewriter.
eraseOp(clonedYield);
1506 rewriter.
eraseOp(placeholder);
1515 UnaryOp unop = cast<UnaryOp>(op);
1516 Region &presentRegion = unop.getPresentRegion();
1517 if (presentRegion.
empty())
1529 BinaryOp binop = cast<BinaryOp>(op);
1530 Region &overlapRegion = binop.getOverlapRegion();
1531 if (overlapRegion.
empty())
1535 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1540 const auto &expr =
exp(e);
1541 switch (expr.kind) {
1547 llvm_unreachable(
"unexpected non-op");
1550 return rewriter.
create<math::AbsFOp>(loc, v0);
1552 auto type = cast<ComplexType>(v0.
getType());
1553 auto eltType = cast<FloatType>(type.getElementType());
1554 return rewriter.
create<complex::AbsOp>(loc, eltType, v0);
1557 return rewriter.
create<math::AbsIOp>(loc, v0);
1559 return rewriter.
create<math::CeilOp>(loc, v0);
1561 return rewriter.
create<math::FloorOp>(loc, v0);
1563 return rewriter.
create<math::SqrtOp>(loc, v0);
1565 return rewriter.
create<complex::SqrtOp>(loc, v0);
1567 return rewriter.
create<math::ExpM1Op>(loc, v0);
1569 return rewriter.
create<complex::Expm1Op>(loc, v0);
1571 return rewriter.
create<math::Log1pOp>(loc, v0);
1573 return rewriter.
create<complex::Log1pOp>(loc, v0);
1575 return rewriter.
create<math::SinOp>(loc, v0);
1577 return rewriter.
create<complex::SinOp>(loc, v0);
1579 return rewriter.
create<math::TanhOp>(loc, v0);
1581 return rewriter.
create<complex::TanhOp>(loc, v0);
1583 return rewriter.
create<arith::NegFOp>(loc, v0);
1585 return rewriter.
create<complex::NegOp>(loc, v0);
1587 return rewriter.
create<arith::SubIOp>(
1593 return rewriter.
create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1595 return rewriter.
create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1597 return rewriter.
create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1599 return rewriter.
create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1601 return rewriter.
create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1603 return rewriter.
create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1605 return rewriter.
create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1607 return rewriter.
create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1609 return rewriter.
create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1611 return rewriter.
create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1613 auto type = cast<ComplexType>(v0.
getType());
1614 auto eltType = cast<FloatType>(type.getElementType());
1615 return rewriter.
create<complex::ImOp>(loc, eltType, v0);
1618 auto type = cast<ComplexType>(v0.
getType());
1619 auto eltType = cast<FloatType>(type.getElementType());
1620 return rewriter.
create<complex::ReOp>(loc, eltType, v0);
1623 return rewriter.
create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1626 return rewriter.
create<arith::MulFOp>(loc, v0, v1);
1628 return rewriter.
create<complex::MulOp>(loc, v0, v1);
1630 return rewriter.
create<arith::MulIOp>(loc, v0, v1);
1632 return rewriter.
create<arith::DivFOp>(loc, v0, v1);
1634 return rewriter.
create<complex::DivOp>(loc, v0, v1);
1636 return rewriter.
create<arith::DivSIOp>(loc, v0, v1);
1638 return rewriter.
create<arith::DivUIOp>(loc, v0, v1);
1640 return rewriter.
create<arith::AddFOp>(loc, v0, v1);
1642 return rewriter.
create<complex::AddOp>(loc, v0, v1);
1644 return rewriter.
create<arith::AddIOp>(loc, v0, v1);
1646 return rewriter.
create<arith::SubFOp>(loc, v0, v1);
1648 return rewriter.
create<complex::SubOp>(loc, v0, v1);
1650 return rewriter.
create<arith::SubIOp>(loc, v0, v1);
1652 return rewriter.
create<arith::AndIOp>(loc, v0, v1);
1654 return rewriter.
create<arith::OrIOp>(loc, v0, v1);
1656 return rewriter.
create<arith::XOrIOp>(loc, v0, v1);
1658 return rewriter.
create<arith::ShRSIOp>(loc, v0, v1);
1660 return rewriter.
create<arith::ShRUIOp>(loc, v0, v1);
1662 return rewriter.
create<arith::ShLIOp>(loc, v0, v1);
1664 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1665 return rewriter.
create<arith::CmpIOp>(loc, predicate, v0, v1);
1668 auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1669 return rewriter.
create<arith::CmpFOp>(loc, predicate, v0, v1);
1677 return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1682 ReduceOp redOp = cast<ReduceOp>(expr.op);
1683 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1694 llvm_unreachable(
"unexpected expression kind in build");
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
TypedAttr getZeroAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Block * getBlock() const
Returns the current block of the builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns a floating point value.
Specialization of arith.constant op that returns an integer of index type.
Specialization of arith.constant op that returns an integer value.
LatPointId conjLat(ExprId e, LatPointId p0, LatPointId p1, Operation *op=nullptr)
Computes a single conjunction of two lattice points by taking the "union" of LoopId (effectively cons...
LatSetId disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Disjunctive merge of two lattice sets: (s0 /\_op s1, s0, s1).
bool isSingleCondition(TensorId t, ExprId e) const
Returns true if given tensor iterates only in the given tensor expression.
bool hasSparseIdxReduction(const BitVector &bits) const
Returns true if bits contains a dependent index reduction condition on sparse levels.
bool expContainsTensor(ExprId e, TensorId t) const
Returns true if the expression contains the tensor as an operand.
LatSetId mapBinWithSynZeroSet(ExprId e, LatSetId s, bool lhsZero)
Maps the binary operator to the same operation but with one of its operand set to zero,...
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a sparse tensor level contains non-trivial index expressio...
void dumpBits(const BitVector &bits) const
LatSetId addSet()
Constructs a new (initially empty) set, and returns its identifier.
BitVector simplifyCond(LatSetId s, LatPointId p)
Simplifies the conditions in a conjunction of a given lattice point within the given set using just t...
bool hasNegateOnOut(ExprId e) const
Returns true if the expression contains a negation on output tensor.
bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const
Checks whether the TensorLoopId represents a tensor level contains non-trivial index expression.
LatSetId disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1)
Disjunctive merge of two lattice sets and also set one of the operand to zero: (s0 /\_op s1 (e0 op e1...
void dumpSet(LatSetId s) const
void dumpLat(LatPointId p) const
LatSetId combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright)
Disjunctive merge of two lattice sets with custom handling of the overlap, left, and right regions.
ExprId addTensorExp(TensorId t)
Constructs a new tensor expression, and returns its identifier.
LatSetId buildLattices(ExprId e, LoopId i)
Builds the iteration lattices in a bottom-up traversal given the remaining tensor (sub)expression and...
LatSetId conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op=nullptr)
Conjunctive merge of two lattice sets: (s0 /\_op s1).
ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1=detail::kInvalidId, Operation *op=nullptr, Attribute attr=nullptr)
Constructs a new unary or binary expression, and returns its identifier.
ExprId addSynZeroExp()
Constructs a new synthetic zero expression.
constexpr LoopId makeLoopId(unsigned i) const
Safely converts the argument to a loop identifier.
std::optional< ExprId > buildTensorExpFromLinalg(linalg::GenericOp op)
Builds a tensor expression from the given Linalg operation.
ArrayRef< LatPointId > set(LatSetId s) const
LatSetId optimizeSet(LatSetId s)
Optimizes the iteration lattice points in the given set.
constexpr TensorId tensor(TensorLoopId b) const
Gets the tensor-identifier of the TensorLoopId.
void dumpExp(ExprId e) const
Print methods (for debugging).
Merger(unsigned numInputOutputTensors, unsigned numLoops, unsigned maxLvlRank)
Constructs a merger for the given number of tensors and loops.
bool hasAnySparse(const BitVector &bits) const
Returns true if any TensorLoopId in the bitvector corresponds to sparse level-type.
LatPointId addLat(TensorId t, LoopId i, ExprId e)
Constructs a new iteration lattice point, and returns its identifier.
ExprId addLoopVarExp(LoopId i)
Constructs a new loop-variable expression, and returns its identifier.
bool latGT(LatPointId p0, LatPointId p1) const
Returns true if p0 > p1.
const TensorExp & exp(ExprId e) const
Convenience getters to immediately access the stored nodes.
constexpr LoopId loop(TensorLoopId b) const
Gets the loop-identifier of the TensorLoopId.
const LatPoint & lat(LatPointId p) const
bool onlyDenseDiff(LatPointId p0, LatPointId p1) const
Returns true if p0 and p1 only differ in dense.
ExprId addInvariantExp(Value v)
Constructs a new invariant expression, and returns its identifier.
constexpr TensorId makeTensorId(unsigned t) const
Safely converts the argument to a tensor identifier.
LevelType getLvlType(TensorId t, LoopId i) const
Gets the level-type of the tth tensor on ith loop.
Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const
Rebuilds SSA format from a tensor expression.
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const
Safely converts the arguments to a pair of (tensor,loop) identifiers.
bool expIsTensor(ExprId e, TensorId t) const
Returns true if the expression is (kTensor t).
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v=Value(), Operation *op=nullptr)
Maps the unary operator over the lattice set of the operand, i.e.
@ Type
An inlay hint that for a type annotation.
static constexpr unsigned kInvalidId
A constant serving as the canonically invalid identifier, regardless of the identifier type.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v)
Ensures that the sparsifier can generate code for expression.
constexpr const char * toMLIRString(LevelType lt)
Returns string representation of the given dimension level type.
unsigned LatSetId
LatSet identifiers.
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, ValueRange vals)
constexpr bool isLooseCompressedLT(LevelType lt)
Check if the LevelType is loose compressed (regardless of properties).
std::pair< Level, LevelType > LvlLTPair
A pair of level and its corresponding LevelType of a tensor.
unsigned TensorLoopId
A compressed representation of std::pair<TensorId, LoopId>.
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0)
uint64_t Level
The type of level identifiers and level-ranks.
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1)
static bool isAdmissibleBranch(Operation *op, Region ®ion)
Ensures that the sparsifier can generate code for branch.
unsigned LoopId
Loop identifiers.
constexpr bool is2OutOf4LT(LevelType lt)
Check if the LevelType is 2OutOf4 (regardless of properties).
constexpr bool isDenseLT(LevelType lt)
Check if the LevelType is dense (regardless of properties).
static const char * kindToOpSymbol(TensorExp::Kind kind)
constexpr bool isSingletonLT(LevelType lt)
Check if the LevelType is singleton (regardless of properties).
LevelType
This enum defines all the sparse representations supportable by the SparseTensor dialect.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
constexpr bool isCompressedLT(LevelType lt)
Check if the LevelType is compressed (regardless of properties).
unsigned ExprId
TensorExp identifiers.
static ExpArity getExpArity(TensorExp::Kind k)
unsigned LatPointId
LatPoint identifiers.
std::pair< LoopId, unsigned > LoopCoeffPair
A pair of loop id and its coefficients.
unsigned TensorId
Tensor identifiers, chosen to be the BlockArgument::getArgNumber of the value passed to Merger::build...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
BitVector bits
Conjunction of all TensorLoopIds involved in the tensor expression.
LoopId loop
kLoopVar expressions simply have a loop identifier.
Value val
Direct link to IR for an invariant or the destination value (to infer destination type) of a cast ope...
Kind
Tensor expression kind.
Children children
All other expressions hold the ExprIds of their children.
Attribute attr
An optional attribute that is required to determine the semantics of the operations.
TensorId tensor
kTensor expressions simply have a tensor identifier.
Kind kind
Tensor expression kind.
TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op, Attribute a)
The x parameter has different types depending on the value of the k parameter.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.