50 return enc && !llvm::all_of(enc.getLvlTypes(),
51 [](
auto lt) { return lt == LevelFormat::Dense; });
57 Value val = op->get();
75 if (
auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
80 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
89 if (
auto arg = dyn_cast<BlockArgument>(val))
92 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
102 if (
auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
105 return (def->getOperand(0) == x &&
isMulChain(def->getOperand(1), x)) ||
106 (def->getOperand(1) == x &&
isMulChain(def->getOperand(0), x));
115 if (
auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116 if (arg.getOwner()->getParentOp() == op) {
127 for (
const auto &d :
enumerate(stp.getShape())) {
129 if (d.value() == ShapedType::kDynamic)
130 dim = builder.
create<tensor::DimOp>(loc, tensor, d.index());
133 sizes.push_back(dim);
148 for (
const auto &d :
enumerate(tp.getShape())) {
149 if (d.value() == ShapedType::kDynamic)
150 dynSizes.push_back(sizes[d.index()]);
156 SparseElementsAttr attr) {
162 rewriter, loc, attr, op.getOrder().value_or(
AffineMap()),
165 args.append(cvs.begin(), cvs.end());
169 auto cloned = cast<ForeachOp>(rewriter.
clone(*op.getOperation()));
170 assert(args.size() == cloned.getBody()->getNumArguments());
171 Operation *yield = cloned.getBody()->getTerminator();
175 reduc = yield->getOperands();
189 auto dstShape = dstTp.getShape();
193 if (dstShape[dim] != ShapedType::kDynamic) {
198 for (
const auto &src : srcs.drop_front()) {
201 sizes[dim] = builder.
create<arith::AddIOp>(loc, sizes[dim], srcSz);
226 struct FuseExtractSliceWithConcat
230 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
232 auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
237 int64_t dim = concatOp.getDim();
238 int64_t rank = extractOp.getResultType().getRank();
247 for (
auto [idx, input] :
250 partialSums.push_back(sum);
251 offsetStrides.push_back(
254 auto partialSumMap =
AffineMap::get(concatOp.getInputs().size(), 0,
258 rewriter, loc, partialSumMap, offsetStrides);
261 for (
auto [l, r] : llvm::zip(lhs, rhs)) {
266 return lhs.size() == rhs.size();
269 for (
auto [i, input, offset] :
273 srcOffsets[dim] = offset;
279 if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280 allEqual(srcStrides, dstStrides)) {
281 Value operand = concatOp.getOperand(i);
282 if (operand.
getType() == extractOp.getResultType())
299 if (!op.hasPureTensorSemantics() || op.
getNumResults() != 1 ||
307 rewriter.
replaceOp(op, op.getDpsInitOperand(0)->get());
311 if (!outputType.hasStaticShape())
313 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
342 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
344 op.getNumParallelLoops() != op.getNumLoops() ||
345 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
346 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
347 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
359 auto prod = dyn_cast_or_null<GenericOp>(
360 op.getDpsInputOperand(other)->get().getDefiningOp());
361 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
362 !prod.getResult(0).hasOneUse())
374 inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
375 fusedIndexMaps.push_back(fusedIndexMaps.back());
377 auto fusedOp = rewriter.
create<GenericOp>(
386 for (
unsigned i = 0; i < num - 1; i++)
387 addArg(mapper, fusedBlock, prodBlock.getArgument(i));
388 addArg(mapper, fusedBlock, consBlock.
getArgument(1 - other));
389 addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
391 auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
394 for (
auto &op : prodBlock.without_terminator())
397 rewriter.
clone(op, mapper);
402 rewriter.
create<linalg::YieldOp>(loc, last);
406 Value init = prod.getDpsInitOperand(0)
411 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
412 rewriter.
modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
416 rewriter.
replaceOp(op, fusedOp->getResults());
438 Type srcType = op.getSource().getType();
439 Type dstType = op.getDest().getType();
441 if (srcType == dstType) {
447 if (
Operation *def = op.getSource().getDefiningOp()) {
448 if (def->
hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
497 auto matched = isRewritablePattern(op, &inst);
498 if (!matched.has_value())
502 auto [c, t, f] = matched.value();
503 assert(t.getType() == f.getType());
504 auto selTp = t.getType();
506 auto binOp = rewriter.
create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
508 rewriter.
createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
509 {t.getLoc(), f.getLoc()});
510 rewriter.
createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
511 rewriter.
createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
513 for (
auto *r : binOp.getRegions()) {
521 if (
auto *def = c.getDefiningOp())
525 if (r == &binOp.getLeftRegion()) {
528 }
else if (r == &binOp.getRightRegion()) {
536 rewriter.
create<sparse_tensor::YieldOp>(loc, y);
542 semiRings.emplace_back(&inst, binOp);
546 for (
auto [sel, semi] : semiRings)
547 rewriter.
replaceOp(sel, semi->getResults());
549 return success(!semiRings.empty());
553 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
554 isRewritablePattern(GenericOp op,
Operation *v) {
555 auto sel = dyn_cast<arith::SelectOp>(v);
559 auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
560 auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
569 auto isValFromDenseInputOrInvariant = [&op](
Value v) ->
bool {
570 if (
auto bArg = dyn_cast<BlockArgument>(v);
571 bArg && !
isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
574 return v.getDefiningOp() && v.getDefiningOp()->
getBlock() != op.getBody();
579 auto cond = sel.getCondition();
580 if (isValFromDenseInputOrInvariant(cond))
581 return std::make_tuple(cond, tVal, fVal);
590 if (isValFromDenseInputOrInvariant(cmpL) ||
591 isValFromDenseInputOrInvariant(cmpR))
592 return std::make_tuple(cond, tVal, fVal);
623 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
626 auto *inp = op.getDpsInputOperand(0);
627 auto *init = op.getDpsInitOperand(0);
634 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
635 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
636 arith::MaxUIOp>(red))
640 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
641 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
653 auto semiring = rewriter.
create<sparse_tensor::UnaryOp>(loc, rtp, s0);
655 rewriter.
createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
657 rewriter.
create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
658 rewriter.
createBlock(&semiring.getAbsentRegion(), {}, {}, {});
662 rewriter.
create<sparse_tensor::YieldOp>(loc, zero);
667 auto custom = rewriter.
create<sparse_tensor::ReduceOp>(
668 loc, rtp, semiring.getResult(), s1, identity);
670 rewriter.
createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
673 irMap.
map(red->getOperand(0), region->getArgument(0));
674 irMap.
map(red->getOperand(1), region->getArgument(1));
675 auto *cloned = rewriter.
clone(*red, irMap);
676 rewriter.
create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
678 rewriter.
replaceOp(red, custom.getResult());
692 auto tensor = op.getTensor();
695 auto nse = rewriter.
create<NumberOfEntriesOp>(loc, tensor);
696 rewriter.
create<vector::PrintOp>(
697 loc, rewriter.
getStringAttr(
"---- Sparse Tensor ----\nnse = "));
698 rewriter.
create<vector::PrintOp>(loc, nse);
701 printSizes(rewriter, loc, tensor, stt.getDimRank(),
true);
703 printSizes(rewriter, loc, tensor, stt.getLvlRank(),
false);
711 case SparseTensorFieldKind::StorageSpec: {
714 case SparseTensorFieldKind::PosMemRef: {
717 rewriter.
create<vector::PrintOp>(
718 loc, lvl, vector::PrintPunctuation::NoPunctuation);
720 auto pos = rewriter.
create<ToPositionsOp>(loc, tensor, l);
721 printContents(rewriter, loc, pos);
724 case SparseTensorFieldKind::CrdMemRef: {
727 rewriter.
create<vector::PrintOp>(
728 loc, lvl, vector::PrintPunctuation::NoPunctuation);
734 if (stt.getAoSCOOStart() == l)
735 crd = rewriter.
create<ToCoordinatesBufferOp>(loc, tensor);
737 crd = rewriter.
create<ToCoordinatesOp>(loc, tensor, l);
738 printContents(rewriter, loc, crd);
741 case SparseTensorFieldKind::ValMemRef: {
742 rewriter.
create<vector::PrintOp>(loc,
744 auto val = rewriter.
create<ToValuesOp>(loc, tensor);
745 printContents(rewriter, loc, val);
768 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
771 auto size = rewriter.
create<memref::DimOp>(loc, vec, zero);
773 auto forOp = rewriter.
create<scf::ForOp>(loc, zero, size, step);
775 auto idx = forOp.getInductionVar();
776 auto val = rewriter.
create<memref::LoadOp>(loc, vec, idx);
777 if (llvm::isa<ComplexType>(val.getType())) {
780 Value real = rewriter.
create<complex::ReOp>(loc, val);
781 Value imag = rewriter.
create<complex::ImOp>(loc, val);
782 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
783 rewriter.
create<vector::PrintOp>(loc, real,
784 vector::PrintPunctuation::Comma);
785 rewriter.
create<vector::PrintOp>(loc, imag,
786 vector::PrintPunctuation::Close);
787 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
789 rewriter.
create<vector::PrintOp>(loc, val,
790 vector::PrintPunctuation::Comma);
794 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
795 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
800 unsigned size,
bool isDim) {
802 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
804 for (
unsigned i = 0; i < size; i++) {
808 val = rewriter.
create<tensor::DimOp>(loc, tensor, idx);
810 val = rewriter.
create<LvlOp>(loc, tensor, idx);
811 rewriter.
create<vector::PrintOp>(
813 i != size - 1 ? vector::PrintPunctuation::Comma
814 : vector::PrintPunctuation::NoPunctuation);
817 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
818 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
830 Value srcTensor = op.getSource();
834 if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
835 !dstTp.hasStaticDimShape())
844 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, srcTensor);
848 dstTp.withoutDimToLvl(),
849 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
851 Value buffer = rewriter
852 .
create<AllocTensorOp>(loc, bufferTp, dynSizes,
Value(),
867 const auto encSrc = srcTp.getEncoding();
868 ForeachOp foreachOp = rewriter.
create<ForeachOp>(
869 loc, srcTensor, buffer,
872 const Dimension srcRank = srcTp.getDimRank();
874 srcDcvs.reserve(srcRank);
875 for (
Dimension d = 0; d < srcRank; d++) {
877 srcDcvs.push_back(srcLcvs[lvl]);
883 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
888 collapseIdx.push_back(i);
891 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
892 collapsedSizes, collapsedDcvs);
895 for (
Dimension i = 0; i < dstTp.getDimRank(); i++)
896 expandIdx.push_back(i);
899 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
903 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
904 builder.create<sparse_tensor::YieldOp>(loc, t);
907 Value t = rewriter.
create<LoadOp>(loc, foreachOp.getResult(0),
true);
908 if (bufferTp != dstTp) {
909 auto dstRTT = dstTp.getRankedTensorType();
910 Value converted = rewriter.
create<ConvertOp>(loc, dstRTT, t).getResult();
911 rewriter.
create<DeallocTensorOp>(loc, t);
920 template <
typename ReshapeOp>
928 Value srcTensor = op.getSrc();
931 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
940 if (dstTp.hasStaticDimShape()) {
946 op.getReassociationIndices());
948 if (shape == ShapedType::kDynamic)
949 dstDynSizes.push_back(dstSizes[idx]);
952 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, srcTensor);
956 dstTp.withoutDimToLvl(),
957 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
961 .
create<AllocTensorOp>(loc, bufferTp, dstDynSizes,
Value(),
972 const auto encSrc = srcTp.getEncoding();
973 ForeachOp foreachOp = rewriter.
create<ForeachOp>(
974 loc, srcTensor, buffer,
977 const Dimension dimRank = srcTp.getDimRank();
979 srcDcvs.reserve(dimRank);
980 for (
Dimension d = 0; d < dimRank; d++) {
982 srcDcvs.push_back(srcLcvs[lvl]);
985 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
986 srcDcvs, dstSizes, dstDcvs);
988 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
989 builder.create<sparse_tensor::YieldOp>(loc, t);
992 Value t = rewriter.
create<LoadOp>(loc, foreachOp.getResult(0),
true);
993 if (bufferTp != dstTp) {
994 auto dstRTT = dstTp.getRankedTensorType();
995 Value converted = rewriter.
create<ConvertOp>(loc, dstRTT, t).getResult();
996 rewriter.
create<DeallocTensorOp>(loc, t);
1006 template <
typename ReshapeOp>
1020 if (encDst && encSrc) {
1027 auto convert = rewriter.
create<ConvertOp>(loc, denseTp, op.getSrc());
1035 auto reshape = rewriter.
create<ReshapeOp>(loc, denseTp, op.getSrc(),
1036 op.getReassociation());
1037 Value convert = rewriter.
create<ConvertOp>(loc, rtp, reshape);
1053 val = builder.
create<AllocTensorOp>(loc, rtt, dynSzs);
1056 val = builder.
create<linalg::FillOp>(loc, c0, val).getResult(0);
1061 val = builder.
create<tensor::InsertOp>(loc, v, val, crds);
1066 return builder.
create<LoadOp>(loc, val,
true);
1070 bool isSparse()
const {
1077 struct SparseTensorDimOpRewriter :
public OpRewritePattern<tensor::DimOp> {
1081 std::optional<int64_t> dim = op.getConstantIndex();
1083 if (!dim || !stt.hasEncoding())
1086 if (stt.isPermutation()) {
1088 toLvl(stt.getEncoding(), *dim));
1100 for (Level l = 0; l < stt.getLvlRank(); l++) {
1101 Value lvlSz = rewriter.
create<LvlOp>(loc, op.getSource(), l);
1104 maxLvlCrds.push_back(maxLvlCrd);
1107 AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
1108 Value maxDimCrd = rewriter.
create<affine::AffineApplyOp>(
1123 if (op.needsExtraSort())
1124 op.
emitError(
"ConcatenateOp not staged");
1128 const Dimension conDim = op.getDimension();
1145 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1147 Value iterArg = dstBuf.val;
1149 ForeachOp foreachOp;
1150 for (
Value input : op.getInputs()) {
1153 foreachOp = rewriter.
create<ForeachOp>(
1154 loc, input, iterArg,
1159 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1162 dstBuf.val = reduc.front();
1163 if (!dstTp.isAllDense()) {
1165 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1167 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1168 builder.create<scf::YieldOp>(loc, dstBuf.val);
1170 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1171 dstBuf.insert(builder, loc, v, offDimCrd);
1172 builder.create<scf::YieldOp>(loc, dstBuf.val);
1175 builder.setInsertionPointAfter(ifOp);
1176 dstBuf.val = ifOp.getResult(0);
1178 dstBuf.insert(builder, loc, v, offDimCrd);
1180 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1186 assert(!ShapedType::isDynamic(sz));
1187 offset = rewriter.
create<arith::AddIOp>(loc, offset,
1190 dstBuf.val = iterArg;
1193 dstBuf.val = iterArg;
1194 Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1204 if (op.needsExtraSort())
1205 return op.
emitError(
"ConvertOp not staged.");
1210 if (encDst && encSrc && !encSrc.isSlice() &&
1211 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1218 Value src = op.getSource();
1223 bool fromSparseConst =
false;
1224 if (
auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1225 if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1226 fromSparseConst =
true;
1228 const AffineMapAttr foreachOrder =
1233 bool skipZeroCheck = srcStt.
hasEncoding() || fromSparseConst;
1240 auto foreachOp = rewriter.
create<ForeachOp>(
1241 loc, src, dstBuf.val, foreachOrder,
1245 dstBuf.val = reduc.front();
1246 if (!skipZeroCheck) {
1248 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1250 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1251 builder.create<scf::YieldOp>(loc, dstBuf.val);
1253 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1254 dstBuf.insert(builder, loc, v, dcvs);
1255 builder.create<scf::YieldOp>(loc, dstBuf.val);
1258 builder.setInsertionPointAfter(ifOp);
1259 dstBuf.val = ifOp.getResult(0);
1261 dstBuf.insert(builder, loc, v, dcvs);
1263 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1269 dstBuf.val = foreachOp.getResult(0);
1281 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1282 ? op.getEncoder().getDimToLvl()
1283 : op.getEncoder().getLvlToDim();
1290 Value trans = rewriter.
create<affine::AffineApplyOp>(
1293 outCrds.push_back(trans);
1309 Value input = op.getTensor();
1312 const Level lvlRank = stt.getLvlRank();
1316 if (
auto constOp = input.
getDefiningOp<arith::ConstantOp>()) {
1317 if (
auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1323 const auto enc = stt.getEncoding();
1330 for (Level l = 0; l < lvlRank; l++) {
1334 loopEmitter.makeTensorLevel(0, l)};
1335 loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1338 loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1343 if (op.getOrder()) {
1346 "Level order not yet implemented on non-constant input tensors.");
1349 Value vals = loopEmitter.getValBuffer()[0];
1353 Value val = enc ? rewriter.
create<memref::LoadOp>(loc, vals, pos)
1354 : rewriter.
create<memref::LoadOp>(loc, vals, lcvs);
1357 Block *srcBlock = op.getBody();
1361 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1364 args.push_back(val);
1373 if (llvm::isa<scf::YieldOp>(last)) {
1383 for (Level l = 0; l < lvlRank; l++) {
1386 loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1387 loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1404 if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1411 RankedTensorType dstTp = stt.getRankedTensorType();
1412 RankedTensorType cooTp = stt.getCOOType(
true);
1414 Value convert = cooTensor;
1415 auto enc = stt.getEncoding();
1416 if (!stt.isPermutation()) {
1418 convert = rewriter.
create<ReinterpretMapOp>(loc, coo, convert);
1421 convert = rewriter.
create<ConvertOp>(loc, dstTp, convert);
1422 if (!stt.isPermutation())
1423 convert = rewriter.
create<ReinterpretMapOp>(loc, enc, convert);
1428 rewriter.
create<DeallocTensorOp>(loc, cooTensor);
1441 Value src = op.getTensor();
1442 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, src);
1446 const Dimension dimRank = srcTp.getDimRank();
1454 for (
Dimension d = 0; d < dimRank; d++) {
1455 rewriter.
create<memref::StoreOp>(loc, dims[d], dimSizes,
1462 createFuncCall(rewriter, loc,
"createSparseTensorWriter", {opaqueTp},
1463 {op.getDest()}, EmitCInterface::Off)
1466 createFuncCall(rewriter, loc,
"outSparseTensorWriterMetaData", {},
1467 {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1469 Value dimCoords = dimSizes;
1470 Type eltTp = srcTp.getElementType();
1477 rewriter.
create<ForeachOp>(
1478 loc, src, std::nullopt,
1481 for (
Dimension d = 0; d < dimRank; d++) {
1482 rewriter.
create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1485 rewriter.
create<memref::StoreOp>(loc, v, value);
1488 EmitCInterface::On);
1489 builder.create<func::CallOp>(loc,
TypeRange(), fn, operands);
1490 builder.create<sparse_tensor::YieldOp>(loc);
1494 createFuncCall(rewriter, loc,
"delSparseTensorWriter", {}, {writer},
1495 EmitCInterface::Off);
1509 patterns.
add<FuseExtractSliceWithConcat, FoldInvariantYield,
1510 FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
1511 GenSemiRingSelect, PrintRewriter>(patterns.
getContext());
1516 bool enableConvert) {
1517 patterns.
add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1518 ReshapeRewriter<tensor::CollapseShapeOp>,
1519 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1520 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1521 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1533 patterns.
add<CrdTranslateRewriter, ForeachRewriter>(patterns.
getContext());
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static bool isMulChain(Value val, Value x)
static bool isSampling(GenericOp op)
static bool isSumOfMul(GenericOp op)
static bool isZeroValue(Value val)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes for...
static LogicalResult genForeachOnSparseConstant(ForeachOp op, RewriterBase &rewriter, SparseElementsAttr attr)
static bool isMaterializing(OpOperand *op, bool isZero)
static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType dstTp, ValueRange srcs, unsigned dim)
Populates the given sizes array for concatenation from types (for static sizes) and from the source t...
static bool isSparseTensor(Value v)
static bool isZeroYield(GenericOp op)
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType stp, Value tensor)
Populates given sizes array from type (for static sizes) and from the tensor (for dynamic sizes).
@ NewOp
Op vectorized into a new Op whose results will replace the original Op's results.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Block * getBlock() const
Returns the current block of the builder.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
RankedTensorType getRankedTensorType() const
Explicitly convert to RankedTensorType.
AffineMap getExpandedDimToLvl() const
Returns the dimToLvl mapping, where the identity map is expanded out into a full AffineMap.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Returns a function reference (first hit also inserts into module).
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
void foreachInSparseConstant(OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, function_ref< void(ArrayRef< Value >, Value)> callback)
Iterate over a sparse constant, generates constantOp for value and coordinates.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
void genReshapeDstShape(OpBuilder &builder, Location loc, SmallVectorImpl< Value > &dstShape, ArrayRef< Value > srcShape, ArrayRef< Size > staticDstShape, ArrayRef< ReassociationIndices > reassociation)
Computes the shape of destination tensor of a reshape operator.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
void reshapeCvs(OpBuilder &builder, Location loc, ArrayRef< ReassociationIndices > reassociation, ValueRange srcSizes, ValueRange srcCvs, ValueRange dstSizes, SmallVectorImpl< Value > &dstCvs)
Reshape coordinates during a reshaping operation.
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
SparseTensorFieldKind
The sparse tensor storage...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, Value src)
Populates given sizes array from dense tensor or sparse tensor constant.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction...
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.