50 return enc && !llvm::all_of(enc.getLvlTypes(),
51 [](
auto lt) { return lt == LevelFormat::Dense; });
74 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75 if (
auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
78 Value s1 = op.getBlock()->getArgument(0);
79 Value s2 = op.getBlock()->getArgument(1);
80 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
89 if (
auto arg = dyn_cast<BlockArgument>(val))
92 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
101 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102 if (
auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104 Value x = op.getBlock()->getArguments().back();
105 return (def->getOperand(0) == x &&
isMulChain(def->getOperand(1), x)) ||
106 (def->getOperand(1) == x &&
isMulChain(def->getOperand(0), x));
114 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115 if (
auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116 if (arg.getOwner()->getParentOp() == op) {
117 return isZeroValue(op->getOperand(arg.getArgNumber()));
127 for (
const auto &d :
enumerate(stp.getShape())) {
129 if (d.value() == ShapedType::kDynamic)
130 dim = builder.
create<tensor::DimOp>(loc, tensor, d.index());
133 sizes.push_back(dim);
148 for (
const auto &d :
enumerate(tp.getShape())) {
149 if (d.value() == ShapedType::kDynamic)
150 dynSizes.push_back(sizes[d.index()]);
156 SparseElementsAttr attr) {
157 auto loc = op.getLoc();
162 rewriter, loc, attr, op.getOrder().value_or(
AffineMap()),
165 args.append(cvs.begin(), cvs.end());
169 auto cloned = cast<ForeachOp>(rewriter.
clone(*op.getOperation()));
170 assert(args.size() == cloned.getBody()->getNumArguments());
171 Operation *yield = cloned.getBody()->getTerminator();
175 reduc = yield->getOperands();
189 auto dstShape = dstTp.getShape();
193 if (dstShape[dim] != ShapedType::kDynamic) {
198 for (
const auto &src : srcs.drop_front()) {
201 sizes[dim] = builder.
create<arith::AddIOp>(loc, sizes[dim], srcSz);
226 struct FuseExtractSliceWithConcat
230 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
232 auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
237 int64_t dim = concatOp.getDim();
238 int64_t rank = extractOp.getResultType().getRank();
247 for (
auto [idx, input] :
250 partialSums.push_back(sum);
251 offsetStrides.push_back(
254 auto partialSumMap =
AffineMap::get(concatOp.getInputs().size(), 0,
258 rewriter, loc, partialSumMap, offsetStrides);
261 for (
auto [l, r] : llvm::zip(lhs, rhs)) {
266 return lhs.size() == rhs.size();
269 for (
auto [i, input, offset] :
273 srcOffsets[dim] = offset;
279 if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280 allEqual(srcStrides, dstStrides)) {
281 Value operand = concatOp.getOperand(i);
282 if (operand.
getType() == extractOp.getResultType())
297 LogicalResult matchAndRewrite(ConvertOp op,
299 auto producer = op.getSource().getDefiningOp<GenericOp>();
300 if (!producer || producer.getDpsInits().size() != 1 ||
302 !producer.getResult(0).hasOneUse()) {
307 Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
312 producer.getDpsInitsMutable().assign(cloned->
getResults());
313 producer.getResult(0).setType(op.getResult().getType());
328 LogicalResult matchAndRewrite(GenericOp op,
330 if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
332 !
isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
338 rewriter.
replaceOp(op, op.getDpsInitOperand(0)->get());
342 if (!outputType.hasStaticShape())
344 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
370 LogicalResult matchAndRewrite(GenericOp op,
373 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374 op.getNumResults() != 1 ||
375 op.getNumParallelLoops() != op.getNumLoops() ||
376 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
390 auto prod = dyn_cast_or_null<GenericOp>(
391 op.getDpsInputOperand(other)->get().getDefiningOp());
392 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393 !prod.getResult(0).hasOneUse())
405 inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406 fusedIndexMaps.push_back(fusedIndexMaps.back());
408 auto fusedOp = rewriter.
create<GenericOp>(
409 loc, op.getResult(0).getType(), inputOps, outputOps,
417 for (
unsigned i = 0; i < num - 1; i++)
418 addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419 addArg(mapper, fusedBlock, consBlock.
getArgument(1 - other));
420 addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
422 auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
425 for (
auto &op : prodBlock.without_terminator())
427 last = op.getResult(0);
428 rewriter.
clone(op, mapper);
433 rewriter.
create<linalg::YieldOp>(loc, last);
437 Value init = prod.getDpsInitOperand(0)
439 .getDefiningOp<AllocTensorOp>()
442 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443 rewriter.
modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
447 rewriter.
replaceOp(op, fusedOp->getResults());
467 LogicalResult matchAndRewrite(tensor::CastOp op,
469 Type srcType = op.getSource().getType();
470 Type dstType = op.getDest().getType();
472 if (srcType == dstType) {
473 rewriter.
replaceOp(op, op->getResults());
478 if (
Operation *def = op.getSource().getDefiningOp()) {
479 if (def->
hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
481 def->
getResult(0).setType(op->getResultTypes()[0]);
518 LogicalResult matchAndRewrite(GenericOp op,
528 auto matched = isRewritablePattern(op, &inst);
529 if (!matched.has_value())
533 auto [c, t, f] = matched.value();
534 assert(t.getType() == f.getType());
535 auto selTp = t.getType();
537 auto binOp = rewriter.
create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
539 rewriter.
createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540 {t.getLoc(), f.getLoc()});
541 rewriter.
createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542 rewriter.
createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
544 for (
auto *r : binOp.getRegions()) {
545 Block *b = &r->front();
552 if (
auto *def = c.getDefiningOp())
556 if (r == &binOp.getLeftRegion()) {
559 }
else if (r == &binOp.getRightRegion()) {
567 rewriter.
create<sparse_tensor::YieldOp>(loc, y);
573 semiRings.emplace_back(&inst, binOp);
577 for (
auto [sel, semi] : semiRings)
578 rewriter.
replaceOp(sel, semi->getResults());
580 return success(!semiRings.empty());
584 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585 isRewritablePattern(GenericOp op,
Operation *v) {
586 auto sel = dyn_cast<arith::SelectOp>(v);
590 auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591 auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
600 auto isValFromDenseInputOrInvariant = [&op](
Value v) ->
bool {
601 if (
auto bArg = dyn_cast<BlockArgument>(v);
602 bArg && !
isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
605 return v.getDefiningOp() && v.getDefiningOp()->
getBlock() != op.getBody();
610 auto cond = sel.getCondition();
611 if (isValFromDenseInputOrInvariant(cond))
612 return std::make_tuple(cond, tVal, fVal);
621 if (isValFromDenseInputOrInvariant(cmpL) ||
622 isValFromDenseInputOrInvariant(cmpR))
623 return std::make_tuple(cond, tVal, fVal);
651 LogicalResult matchAndRewrite(GenericOp op,
654 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655 op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
657 auto *inp = op.getDpsInputOperand(0);
658 auto *init = op.getDpsInitOperand(0);
662 auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
665 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667 arith::MaxUIOp>(red))
669 Value s0 = op.getBlock()->getArgument(0);
670 Value s1 = op.getBlock()->getArgument(1);
671 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
684 auto semiring = rewriter.
create<sparse_tensor::UnaryOp>(loc, rtp, s0);
686 rewriter.
createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
688 rewriter.
create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
689 rewriter.
createBlock(&semiring.getAbsentRegion(), {}, {}, {});
693 rewriter.
create<sparse_tensor::YieldOp>(loc, zero);
698 auto custom = rewriter.
create<sparse_tensor::ReduceOp>(
699 loc, rtp, semiring.getResult(), s1, identity);
701 rewriter.
createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
704 irMap.
map(red->getOperand(0), region->getArgument(0));
705 irMap.
map(red->getOperand(1), region->getArgument(1));
706 auto *cloned = rewriter.
clone(*red, irMap);
709 rewriter.
replaceOp(red, custom.getResult());
720 LogicalResult matchAndRewrite(PrintOp op,
723 auto tensor = op.getTensor();
726 auto nse = rewriter.
create<NumberOfEntriesOp>(loc, tensor);
727 rewriter.
create<vector::PrintOp>(
728 loc, rewriter.
getStringAttr(
"---- Sparse Tensor ----\nnse = "));
729 rewriter.
create<vector::PrintOp>(loc, nse);
732 printSizes(rewriter, loc, tensor, stt.getDimRank(),
true);
734 printSizes(rewriter, loc, tensor, stt.getLvlRank(),
false);
742 case SparseTensorFieldKind::StorageSpec: {
745 case SparseTensorFieldKind::PosMemRef: {
748 rewriter.
create<vector::PrintOp>(
749 loc, lvl, vector::PrintPunctuation::NoPunctuation);
751 auto pos = rewriter.
create<ToPositionsOp>(loc, tensor, l);
752 printContents(rewriter, loc, pos);
755 case SparseTensorFieldKind::CrdMemRef: {
758 rewriter.
create<vector::PrintOp>(
759 loc, lvl, vector::PrintPunctuation::NoPunctuation);
765 if (stt.getAoSCOOStart() == l)
766 crd = rewriter.
create<ToCoordinatesBufferOp>(loc, tensor);
768 crd = rewriter.
create<ToCoordinatesOp>(loc, tensor, l);
769 printContents(rewriter, loc, crd);
772 case SparseTensorFieldKind::ValMemRef: {
773 rewriter.
create<vector::PrintOp>(loc,
775 auto val = rewriter.
create<ToValuesOp>(loc, tensor);
776 printContents(rewriter, loc, val);
797 auto shape = cast<ShapedType>(vec.
getType()).getShape();
799 printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
808 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
812 auto size = rewriter.
create<memref::DimOp>(loc, vec, index);
814 auto forOp = rewriter.
create<scf::ForOp>(loc, zero, size, step);
815 idxs.push_back(forOp.getInductionVar());
817 if (i < shape.size() - 1) {
819 printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
822 auto val = rewriter.
create<memref::LoadOp>(loc, vec, idxs);
823 if (llvm::isa<ComplexType>(val.getType())) {
826 Value real = rewriter.
create<complex::ReOp>(loc, val);
827 Value imag = rewriter.
create<complex::ImOp>(loc, val);
828 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
829 rewriter.
create<vector::PrintOp>(loc, real,
830 vector::PrintPunctuation::Comma);
831 rewriter.
create<vector::PrintOp>(loc, imag,
832 vector::PrintPunctuation::Close);
834 rewriter.
create<vector::PrintOp>(
835 loc, val, vector::PrintPunctuation::NoPunctuation);
838 auto bound = rewriter.
create<arith::AddIOp>(loc, idxs.back(), step);
839 Value cond = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
841 scf::IfOp ifOp = rewriter.
create<scf::IfOp>(loc, cond,
false);
843 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
848 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
853 unsigned size,
bool isDim) {
855 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
857 for (
unsigned i = 0; i < size; i++) {
861 val = rewriter.
create<tensor::DimOp>(loc, tensor, idx);
863 val = rewriter.
create<LvlOp>(loc, tensor, idx);
864 rewriter.
create<vector::PrintOp>(
866 i != size - 1 ? vector::PrintPunctuation::Comma
867 : vector::PrintPunctuation::NoPunctuation);
870 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
871 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
880 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
883 Value srcTensor = op.getSource();
886 if (!srcTp || !dstTp)
889 if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890 !dstTp->hasStaticDimShape())
899 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, srcTensor);
903 dstTp->withoutDimToLvl(),
904 !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
906 Value buffer = rewriter
907 .
create<AllocTensorOp>(loc, bufferTp, dynSizes,
Value(),
922 const auto encSrc = srcTp->getEncoding();
923 ForeachOp foreachOp = rewriter.
create<ForeachOp>(
924 loc, srcTensor, buffer,
927 const Dimension srcRank = srcTp->getDimRank();
929 srcDcvs.reserve(srcRank);
930 for (
Dimension d = 0; d < srcRank; d++) {
932 srcDcvs.push_back(srcLcvs[lvl]);
938 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
943 collapseIdx.push_back(i);
946 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947 collapsedSizes, collapsedDcvs);
950 for (
Dimension i = 0; i < dstTp->getDimRank(); i++)
951 expandIdx.push_back(i);
954 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
958 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
959 builder.create<sparse_tensor::YieldOp>(loc, t);
962 Value t = rewriter.
create<LoadOp>(loc, foreachOp.getResult(0),
true);
963 if (bufferTp != *dstTp) {
964 auto dstRTT = dstTp->getRankedTensorType();
965 Value converted = rewriter.
create<ConvertOp>(loc, dstRTT, t).getResult();
966 rewriter.
create<DeallocTensorOp>(loc, t);
975 template <
typename ReshapeOp>
980 LogicalResult matchAndRewrite(ReshapeOp op,
983 Value srcTensor = op.getSrc();
986 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
995 if (dstTp.hasStaticDimShape()) {
1001 op.getReassociationIndices());
1003 if (shape == ShapedType::kDynamic)
1004 dstDynSizes.push_back(dstSizes[idx]);
1007 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, srcTensor);
1011 dstTp.withoutDimToLvl(),
1012 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1016 .
create<AllocTensorOp>(loc, bufferTp, dstDynSizes,
Value(),
1027 const auto encSrc = srcTp.getEncoding();
1028 ForeachOp foreachOp = rewriter.
create<ForeachOp>(
1029 loc, srcTensor, buffer,
1032 const Dimension dimRank = srcTp.getDimRank();
1034 srcDcvs.reserve(dimRank);
1035 for (
Dimension d = 0; d < dimRank; d++) {
1037 srcDcvs.push_back(srcLcvs[lvl]);
1040 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1041 srcDcvs, dstSizes, dstDcvs);
1043 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1044 builder.create<sparse_tensor::YieldOp>(loc, t);
1047 Value t = rewriter.
create<LoadOp>(loc, foreachOp.getResult(0),
true);
1048 if (bufferTp != dstTp) {
1049 auto dstRTT = dstTp.getRankedTensorType();
1050 Value converted = rewriter.
create<ConvertOp>(loc, dstRTT, t).getResult();
1051 rewriter.
create<DeallocTensorOp>(loc, t);
1061 template <
typename ReshapeOp>
1066 LogicalResult matchAndRewrite(ReshapeOp op,
1075 if (encDst && encSrc) {
1082 auto convert = rewriter.
create<ConvertOp>(loc, denseTp, op.getSrc());
1091 if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1092 reshape = rewriter.
create<ReshapeOp>(
1093 loc, denseTp, op.getSrc(), op.getReassociation(),
1094 op.getOutputShape(), op.getStaticOutputShape());
1096 reshape = rewriter.
create<ReshapeOp>(loc, denseTp, op.getSrc(),
1097 op.getReassociation());
1099 Value convert = rewriter.
create<ConvertOp>(loc, rtp, reshape);
1115 val = builder.
create<AllocTensorOp>(loc, rtt, dynSzs);
1118 val = builder.
create<linalg::FillOp>(loc, c0, val).getResult(0);
1123 val = builder.
create<tensor::InsertOp>(loc, v, val, crds);
1128 return builder.
create<LoadOp>(loc, val,
true);
1132 bool isSparse()
const {
1139 struct SparseTensorDimOpRewriter :
public OpRewritePattern<tensor::DimOp> {
1141 LogicalResult matchAndRewrite(tensor::DimOp op,
1143 std::optional<int64_t> dim = op.getConstantIndex();
1145 if (!dim || !stt || !stt->hasEncoding())
1148 if (stt->isPermutation()) {
1150 toLvl(stt->getEncoding(), *dim));
1162 for (Level l = 0; l < stt->getLvlRank(); l++) {
1163 Value lvlSz = rewriter.
create<LvlOp>(loc, op.getSource(), l);
1166 maxLvlCrds.push_back(maxLvlCrd);
1169 AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170 Value maxDimCrd = rewriter.
create<affine::AffineApplyOp>(
1183 LogicalResult matchAndRewrite(ConcatenateOp op,
1185 if (op.needsExtraSort())
1186 op.emitError(
"ConcatenateOp not staged");
1190 const Dimension conDim = op.getDimension();
1207 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1209 Value iterArg = dstBuf.val;
1211 ForeachOp foreachOp;
1212 for (
Value input : op.getInputs()) {
1215 foreachOp = rewriter.
create<ForeachOp>(
1216 loc, input, iterArg,
1221 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1224 dstBuf.val = reduc.front();
1225 if (!dstTp.isAllDense()) {
1227 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1229 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1230 builder.create<scf::YieldOp>(loc, dstBuf.val);
1232 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1233 dstBuf.insert(builder, loc, v, offDimCrd);
1234 builder.create<scf::YieldOp>(loc, dstBuf.val);
1237 builder.setInsertionPointAfter(ifOp);
1238 dstBuf.val = ifOp.getResult(0);
1240 dstBuf.insert(builder, loc, v, offDimCrd);
1242 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1248 assert(!ShapedType::isDynamic(sz));
1249 offset = rewriter.
create<arith::AddIOp>(loc, offset,
1252 dstBuf.val = iterArg;
1255 dstBuf.val = iterArg;
1256 Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1264 LogicalResult matchAndRewrite(ConvertOp op,
1266 if (op.needsExtraSort())
1267 return op.emitError(
"ConvertOp not staged.");
1272 if (encDst && encSrc && !encSrc.isSlice() &&
1273 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1280 Value src = op.getSource();
1285 bool fromSparseConst =
false;
1286 if (
auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1287 if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1288 fromSparseConst =
true;
1290 const AffineMapAttr foreachOrder =
1295 bool skipZeroCheck = srcStt.
hasEncoding() || fromSparseConst;
1302 auto foreachOp = rewriter.
create<ForeachOp>(
1303 loc, src, dstBuf.val, foreachOrder,
1307 dstBuf.val = reduc.front();
1308 if (!skipZeroCheck) {
1310 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1312 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1313 builder.create<scf::YieldOp>(loc, dstBuf.val);
1315 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1316 dstBuf.insert(builder, loc, v, dcvs);
1317 builder.create<scf::YieldOp>(loc, dstBuf.val);
1320 builder.setInsertionPointAfter(ifOp);
1321 dstBuf.val = ifOp.getResult(0);
1323 dstBuf.insert(builder, loc, v, dcvs);
1325 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1331 dstBuf.val = foreachOp.getResult(0);
1341 LogicalResult matchAndRewrite(CrdTranslateOp op,
1343 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1344 ? op.getEncoder().getDimToLvl()
1345 : op.getEncoder().getLvlToDim();
1352 Value trans = rewriter.
create<affine::AffineApplyOp>(
1355 outCrds.push_back(trans);
1367 LogicalResult matchAndRewrite(ForeachOp op,
1370 auto loc = op.getLoc();
1371 Value input = op.getTensor();
1374 const Level lvlRank = stt.getLvlRank();
1378 if (
auto constOp = input.
getDefiningOp<arith::ConstantOp>()) {
1379 if (
auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1385 const auto enc = stt.getEncoding();
1392 for (Level l = 0; l < lvlRank; l++) {
1396 loopEmitter.makeTensorLevel(0, l)};
1397 loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1400 loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1405 if (op.getOrder()) {
1408 "Level order not yet implemented on non-constant input tensors.");
1411 Value vals = loopEmitter.getValBuffer()[0];
1415 Value val = enc ? rewriter.
create<memref::LoadOp>(loc, vals, pos)
1416 : rewriter.
create<memref::LoadOp>(loc, vals, lcvs);
1419 Block *srcBlock = op.getBody();
1423 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1426 args.push_back(val);
1435 if (llvm::isa<scf::YieldOp>(last)) {
1445 for (Level l = 0; l < lvlRank; l++) {
1448 loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1449 loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1462 LogicalResult matchAndRewrite(
NewOp op,
1466 if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1473 RankedTensorType dstTp = stt.getRankedTensorType();
1474 RankedTensorType cooTp = stt.getCOOType(
true);
1476 Value convert = cooTensor;
1477 auto enc = stt.getEncoding();
1478 if (!stt.isPermutation()) {
1480 convert = rewriter.
create<ReinterpretMapOp>(loc, coo, convert);
1483 convert = rewriter.
create<ConvertOp>(loc, dstTp, convert);
1484 if (!stt.isPermutation())
1485 convert = rewriter.
create<ReinterpretMapOp>(loc, enc, convert);
1490 rewriter.
create<DeallocTensorOp>(loc, cooTensor);
1499 LogicalResult matchAndRewrite(OutOp op,
1503 Value src = op.getTensor();
1504 Value nnz = rewriter.
create<NumberOfEntriesOp>(loc, src);
1508 const Dimension dimRank = srcTp.getDimRank();
1516 for (
Dimension d = 0; d < dimRank; d++) {
1517 rewriter.
create<memref::StoreOp>(loc, dims[d], dimSizes,
1524 createFuncCall(rewriter, loc,
"createSparseTensorWriter", {opaqueTp},
1525 {op.getDest()}, EmitCInterface::Off)
1528 createFuncCall(rewriter, loc,
"outSparseTensorWriterMetaData", {},
1529 {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1531 Value dimCoords = dimSizes;
1532 Type eltTp = srcTp.getElementType();
1536 ModuleOp module = op->getParentOfType<ModuleOp>();
1539 rewriter.
create<ForeachOp>(
1540 loc, src, std::nullopt,
1543 for (
Dimension d = 0; d < dimRank; d++) {
1544 rewriter.
create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1547 rewriter.
create<memref::StoreOp>(loc, v, value);
1550 EmitCInterface::On);
1551 builder.create<func::CallOp>(loc,
TypeRange(), fn, operands);
1552 builder.create<sparse_tensor::YieldOp>(loc);
1556 createFuncCall(rewriter, loc,
"delSparseTensorWriter", {}, {writer},
1557 EmitCInterface::Off);
1571 patterns.
add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1572 FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1573 GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1579 bool enableConvert) {
1580 patterns.
add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1581 ReshapeRewriter<tensor::CollapseShapeOp>,
1582 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1583 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1584 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1596 patterns.
add<CrdTranslateRewriter, ForeachRewriter>(patterns.
getContext());
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static bool isMulChain(Value val, Value x)
static bool isSampling(GenericOp op)
static bool isSumOfMul(GenericOp op)
static bool isZeroValue(Value val)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static LogicalResult genForeachOnSparseConstant(ForeachOp op, RewriterBase &rewriter, SparseElementsAttr attr)
static bool isMaterializing(OpOperand *op, bool isZero)
static void concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType dstTp, ValueRange srcs, unsigned dim)
Populates the given sizes array for concatenation from types (for static sizes) and from the source t...
static bool isSparseTensor(Value v)
static bool isZeroYield(GenericOp op)
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, ShapedType stp, Value tensor)
Populates given sizes array from type (for static sizes) and from the tensor (for dynamic sizes).
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Block * getBlock() const
Returns the current block of the builder.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater=nullptr, SynTensorBoundSetter synSetter=nullptr)
Starts a loop emitting session by generating all the buffers needed for iterating over the tensors.
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
RankedTensorType getRankedTensorType() const
Explicitly convert to RankedTensorType.
AffineMap getExpandedDimToLvl() const
Returns the dimToLvl mapping, where the identity map is expanded out into a full AffineMap.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Returns a function reference (first hit also inserts into module).
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
void foreachInSparseConstant(OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, function_ref< void(ArrayRef< Value >, Value)> callback)
Iterate over a sparse constant, generates constantOp for value and coordinates.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v)
Generates the comparison v != 0 where v is of numeric type.
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
void genReshapeDstShape(OpBuilder &builder, Location loc, SmallVectorImpl< Value > &dstShape, ArrayRef< Value > srcShape, ArrayRef< Size > staticDstShape, ArrayRef< ReassociationIndices > reassociation)
Computes the shape of destination tensor of a reshape operator.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
void reshapeCvs(OpBuilder &builder, Location loc, ArrayRef< ReassociationIndices > reassociation, ValueRange srcSizes, ValueRange srcCvs, ValueRange dstSizes, SmallVectorImpl< Value > &dstCvs)
Reshape coordinates during a reshaping operation.
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl< Value > &sizes, Location loc, Value src)
Populates given sizes array from dense tensor or sparse tensor constant.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void populatePreSparsificationRewriting(RewritePatternSet &patterns)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, bool enableRT, bool enableConvert)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This enum defines all the sparse representations supportable by the SparseTensor dialect.