// Helper to detect a sparse tensor type operand: it has a sparse encoding
// with at least one non-dense level.
static bool isSparseTensor(Value v) {
  auto enc = getSparseTensorEncoding(v.getType());
  return enc && !llvm::all_of(enc.getLvlTypes(),
                              [](auto lt) { return lt == LevelFormat::Dense; });
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
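// Helper to detect a "sampling" kernel: the linalg.generic region yields the
// product of its first two block arguments, in either operand order, as in
// X(i,j) = S(i,j) * T(i,j).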
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments are used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}
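// Helper to detect a chain of multiplications that does not involve x: every
// leaf is a block argument other than x and every interior node is an
// arith.mul.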
static bool isMulChain(Value val, Value x) {
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  if (auto *def = val.getDefiningOp())
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  return false;
}
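// Helper to detect a direct sum-of-products reduction: the region yields
// x + <mul chain> in either operand order, where x is the last block
// argument, i.e. the accumulator of the reduction.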
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}
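// Helper to detect a kernel that simply yields a zero value: either the
// yielded value is a block argument whose corresponding operand is a
// constant zero, or the yielded value itself is a constant zero.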
static bool isZeroYield(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
    if (arg.getOwner()->getParentOp() == op)
      return isZeroValue(op->getOperand(arg.getArgNumber()));
  }
  return isZeroValue(yieldOp.getOperand(0));
}
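// Populates the given sizes array from the type (for static sizes) and from
// the tensor (for dynamic sizes).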
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) {
  for (const auto &d : enumerate(stp.getShape())) {
    Value dim;
    if (d.value() == ShapedType::kDynamic)
      dim = tensor::DimOp::create(builder, loc, tensor, d.index());
    else
      dim = constantIndex(builder, loc, d.value());
    sizes.push_back(dim);
  }
}
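// Collects the dynamic dimension sizes for tp, under the assumption that
// sizes holds the dimension sizes of a tensor of that type, in order.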
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) {
  for (const auto &d : enumerate(tp.getShape())) {
    if (d.value() == ShapedType::kDynamic)
      dynSizes.push_back(sizes[d.index()]);
  }
}
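// Fully expands a foreach over a sparse constant: every stored (coordinates,
// value) pair of the constant is dispatched to a fresh clone of the foreach
// body, threading the reduction values through the clones.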
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) {
  auto loc = op.getLoc();
  SmallVector<Value> reduc = op.getInitArgs();
  foreachInSparseConstant(
      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
      [&](ArrayRef<Value> cvs, Value v) {
        // Assemble the block arguments: coordinates, value, reductions.
        SmallVector<Value> args;
        args.append(cvs.begin(), cvs.end());
        args.push_back(v);
        args.append(reduc.begin(), reduc.end());
        // Clone the foreach op to get a private copy of the loop body,
        // inline that body at the current position, and harvest the
        // reduction values from its terminator.
        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
        assert(args.size() == cloned.getBody()->getNumArguments());
        Operation *yield = cloned.getBody()->getTerminator();
        rewriter.inlineBlockBefore(cloned.getBody(), op, args);
        rewriter.eraseOp(cloned);
        reduc = yield->getOperands();
        rewriter.eraseOp(yield);
      });
  rewriter.replaceOp(op, reduc);
  return success();
}
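// Populates the given sizes array for concatenation from the destination
// type (for static sizes) and from the source tensors (for dynamic sizes).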
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) {
  auto dstShape = dstTp.getShape();
  sizesFromSrc(builder, sizes, loc, srcs[0]);
  if (dstShape[dim] != ShapedType::kDynamic) {
    // Faithfully take the static size.
    sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
  } else {
    // Else, compute the size dynamically by summing up the sizes of all
    // inputs along the concatenation dimension.
    for (const auto &src : srcs.drop_front()) {
      Value srcSz = createOrFoldDimOp(builder, loc, src, dim);
      sizes[dim] = arith::AddIOp::create(builder, loc, sizes[dim], srcSz);
    }
  }
}
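// Fuses a tensor.extract_slice into a tensor.concat when the slice covers
// exactly one concatenated input. E.g., roughly (illustrative IR, not
// verbatim):
//
//   %c = tensor.concat dim(0) %a, %b : ... -> tensor<8xf64>
//   %s = tensor.extract_slice %c[4] [4] [1] : ... to tensor<4xf64>
//
// folds away into a direct use of %b.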
struct FuseExtractSliceWithConcat
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
    if (!concatOp)
      return failure();

    Location loc = extractOp.getLoc();
    int64_t dim = concatOp.getDim();
    int64_t rank = extractOp.getResultType().getRank();

    // Offsets/sizes/strides of the slice itself (dst*), and unit strides
    // for the candidate source (src*).
    SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
    SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
    SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));

    // Compute the partial sums of the input sizes along the concatenation
    // dimension; these are the offsets at which each input starts.
    AffineExpr sum = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> partialSums = {sum};
    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
    for (auto [idx, input] :
         llvm::enumerate(concatOp.getInputs().drop_back())) {
      sum = sum + rewriter.getAffineDimExpr(idx + 1);
      partialSums.push_back(sum);
      offsetStrides.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
    }
    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
                                        partialSums, rewriter.getContext());
    SmallVector<OpFoldResult> dimOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, partialSumMap, offsetStrides);

    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
      for (auto [l, r] : llvm::zip(lhs, rhs)) {
        std::optional<int64_t> staticVal = getConstantIntValue(l);
        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
          return false;
      }
      return lhs.size() == rhs.size();
    };

    // Find the concatenated input whose offsets, sizes, and strides match
    // the extracted slice exactly, and forward it.
    for (auto [i, input, offset] :
         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
      SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
      srcOffsets[dim] = offset;
      SmallVector<OpFoldResult> srcSizes =
          tensor::getMixedSizes(rewriter, loc, input);
      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
          allEqual(srcStrides, dstStrides)) {
        Value operand = concatOp.getOperand(i);
        if (operand.getType() == extractOp.getResultType())
          rewriter.replaceAllOpUsesWith(extractOp, operand);
        break;
      }
    }

    return success();
  }
};
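// Folds a trailing sparse_tensor.convert into its producing linalg.generic:
// the producer's materializing init is retyped to the sparse result type so
// the kernel yields the converted result directly.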
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getSource().getDefiningOp<GenericOp>();
    if (!producer || producer.getDpsInits().size() != 1 ||
        !isMaterializing(producer.getDpsInitOperand(0), /*isZero=*/false) ||
        !producer.getResult(0).hasOneUse()) {
      return failure();
    }
    // Clone the materializing init op, but with the sparse result type.
    rewriter.setInsertionPoint(producer);
    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
    Operation *cloned = rewriter.clone(*init);
    cloned->getResult(0).setType(op.getResult().getType());

    rewriter.modifyOpInPlace(producer, [&]() {
      producer.getDpsInitsMutable().assign(cloned->getResults());
      producer.getResult(0).setType(op.getResult().getType());
    });

    rewriter.replaceAllOpUsesWith(op, producer);
    op->erase();
    return success();
  }
};
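// Rewriting rule that replaces a kernel that only yields an invariant zero
// over a freshly materialized output by the materialization itself (for
// sparse outputs) or by a constant zero (for static dense outputs).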
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
        !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
      return failure();
    auto outputType = getRankedTensorType(op.getResult(0));
    // Yielding zero on a newly materialized sparse tensor can be optimized
    // away directly (regardless of dynamic or static size).
    if (getSparseTensorEncoding(outputType)) {
      rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
      return success();
    }
    // Use a static zero value directly instead of materialization.
    if (!outputType.hasStaticShape())
      return failure();
    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
    rewriter.eraseOp(def);
    return success();
  }
};
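// Helper to add an argument to block b and record the mapping from the
// original argument in the mapper (signature follows from its uses below).
static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
  mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
}

// Rewriting rule that converts two kernels:
//
//      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
//      X(i,j) = S(i,j) * T(i,j)
//
// into a single kernel, using the distributive law:
//
//      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
//
// This kind of fusion (merging two ops into one, using arithmetic equalities
// that may not hold for floating-point computations) would be undesirable in
// the dense case, since the multiplication is distributed into the reduction
// loop. However, for a sparse sampling tensor S, such as in sampled
// dense-dense matrix multiplication, it avoids materializing T.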
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer: a purely parallel kernel on tensors with identity
    // indexing maps on both inputs and the output.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
      return failure();
    // Find the sampling kernel OP2(sparse, other) or OP2(other, sparse).
    unsigned other = 0;
    if (isSparseTensor(op.getDpsInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getDpsInputOperand(1)))
      return failure();
    // Check producer: a single-result sum-of-products kernel whose only
    // use is this consumer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getDpsInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    if (!isSampling(op) || !isSumOfMul(prod))
      return failure();
    // Combine the operands and indexing maps of producer and consumer,
    // mimicking the map of the sampling operand for the extra input.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputs();
    SmallVector<Value> outputOps = op.getOutputs();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = GenericOp::create(
        rewriter, loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.getRegion().front();
    Block &consBlock = op.getRegion().front();
    IRMapping mapper;
    Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
    // Set up the mapping from the old block arguments to the fused block:
    // producer arguments first, then the sampling operand, then the output.
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone the bodies of producer and consumer in the new evaluation
    // order, keeping track of the last value of the multiplication chain.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*acc, mapper)->getResult(0));
    last = rewriter.clone(*sampler, mapper)->getResult(0);
    linalg::YieldOp::create(rewriter, loc, last);
    // Force the initial value on the merged allocation for dense outputs.
    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
      Value init = prod.getDpsInitOperand(0)
                       ->get()
                       .getDefiningOp<AllocTensorOp>()
                       .getCopy();
      AllocTensorOp a =
          op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
      rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
    }
    // Replace the consumer with the fused operation; the old producer and
    // consumer are removed by dead code elimination.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }
};
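// Rewriting rule that folds tensor.cast: a trivial cast (identical source
// and destination types) disappears, and a cast of a single-use
// tensor.extract_slice is absorbed by retyping the slice result.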
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
public:
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    Type srcType = op.getSource().getType();
    Type dstType = op.getDest().getType();
    // A nop cast simply folds away.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op->getResults());
      return success();
    }
    // A cast of a single-use extract_slice is folded into the slice itself
    // by retyping its result.
    if (Operation *def = op.getSource().getDefiningOp()) {
      if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->getResult(0).setType(op->getResultTypes()[0]);
        });
        rewriter.replaceOp(op, def->getResult(0));
        return success();
      }
    }
    return failure();
  }
};
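// Rewrites arith.select operations inside a sparse kernel into
// sparse_tensor.binary semi-ring operations so the sparsifier can compile
// them correctly. E.g., roughly (illustrative IR, not verbatim):
//
//   %sel = arith.select %cond, %sp1, %sp2
//
// becomes a sparse_tensor.binary %sp1, %sp2 whose overlap/left/right regions
// re-evaluate the select condition on the operand(s) that are present, with
// a zero standing in for the absent operand.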
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Rejects non-sparse kernels.
    if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
      return failure();

    Location loc = op.getLoc();
    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
    for (Operation &inst : *op.getBody()) {
      // Matches the pattern.
      auto matched = isRewritablePattern(op, &inst);
      if (!matched.has_value())
        continue;

      rewriter.setInsertionPoint(&inst);
      auto [c, t, f] = matched.value();
      assert(t.getType() == f.getType());
      auto selTp = t.getType();
      auto c0 = constantZero(rewriter, loc, selTp);
      auto binOp = sparse_tensor::BinaryOp::create(rewriter, loc, selTp, t, f);
      // Initialize all the blocks.
      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
                           {t.getLoc(), f.getLoc()});
      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());

      for (auto *r : binOp.getRegions()) {
        Block *b = &r->front();
        rewriter.setInsertionPointToStart(b);

        IRMapping irMap;
        // Clone the condition into the region, since each region must be
        // self-contained to make the binary op admissible.
        Value newC = c;
        if (auto *def = c.getDefiningOp())
          newC = rewriter.clone(*def, irMap)->getResult(0);
        irMap.map(c, newC);
        if (r == &binOp.getLeftRegion()) {
          // Left region: only the true operand is present.
          irMap.map(t, b->getArgument(0));
          irMap.map(f, c0);
        } else if (r == &binOp.getRightRegion()) {
          // Right region: only the false operand is present.
          irMap.map(t, c0);
          irMap.map(f, b->getArgument(0));
        } else {
          // Overlap region: both operands are present.
          irMap.map(t, b->getArgument(0));
          irMap.map(f, b->getArgument(1));
        }
        auto y = rewriter.clone(inst, irMap)->getResult(0);
        sparse_tensor::YieldOp::create(rewriter, loc, y);
      }

      // Delay the replacement so the iteration over the body stays valid.
      semiRings.emplace_back(&inst, binOp);
    }

    // Finalize the replacement.
    for (auto [sel, semi] : semiRings)
      rewriter.replaceOp(sel, semi->getResults());

    return success(!semiRings.empty());
  }
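  // Helper to match a rewritable arith.select inside the kernel body,
  // returning its (condition, true value, false value) triple on success.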
  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
  isRewritablePattern(GenericOp op, Operation *v) {
    auto sel = dyn_cast<arith::SelectOp>(v);
    if (!sel)
      return std::nullopt;

    // For simplicity, only handle cases where both true/false values are
    // directly loaded from the input tensors.
    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    if (!tVal || !fVal)
      return std::nullopt;

    // Helper lambda to determine whether the value is loaded from a dense
    // input or is a loop invariant.
    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
        return true;
      // If the value is defined outside the loop, it is a loop invariant.
      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
    };

    // If the condition is loaded directly from a dense tensor or is a loop
    // invariant, the kernel can be sparsified.
    auto cond = sel.getCondition();
    if (isValFromDenseInputOrInvariant(cond))
      return std::make_tuple(cond, tVal, fVal);

    // Otherwise, if the condition is computed by a comparison, admit the
    // select when at least one comparison operand is dense or invariant.
    Value cmpL, cmpR;
    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR))) ||
        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR)))) {
      if (isValFromDenseInputOrInvariant(cmpL) ||
          isValFromDenseInputOrInvariant(cmpR))
        return std::make_tuple(cond, tVal, fVal);
    }

    return std::nullopt;
  }
};
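// Rewrites a sparse reduction that would not sparsify directly, since doing
// so would only iterate over the stored elements and ignore the implicit
// zeros, into a semi-ring reduction: a sparse_tensor.unary that maps present
// elements to themselves and absent elements to zero, feeding a
// sparse_tensor.reduce with an explicit identity. Applies to prod/and/min/max
// reductions.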
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Reject non-reductions.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      return failure();
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    if (!isSparseTensor(inp))
      return failure();
    // Look for a direct x = x OP y reduction over a semi-ring ready op.
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      return failure();
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      return failure();
    // The identity of the reduction comes from the init tensor.
    Location loc = op.getLoc();
    Value identity =
        tensor::ExtractOp::create(rewriter, loc, init->get(), ValueRange());
    // Unary { present -> value; absent -> zero }.
    Type rtp = s0.getType();
    rewriter.setInsertionPointToStart(&op.getRegion().front());
    auto semiring = sparse_tensor::UnaryOp::create(rewriter, loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
    sparse_tensor::YieldOp::create(rewriter, loc, present->getArgument(0));
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
    auto zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(rtp));
    sparse_tensor::YieldOp::create(rewriter, loc, zero);
    rewriter.setInsertionPointAfter(semiring);
    // Reduce { x = x REDUC y, identity }.
    auto custom = sparse_tensor::ReduceOp::create(
        rewriter, loc, rtp, semiring.getResult(), s1, identity);
    Block *region =
        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    rewriter.setInsertionPointToStart(&custom.getRegion().front());
    IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
    auto *cloned = rewriter.clone(*red, irMap);
    sparse_tensor::YieldOp::create(rewriter, loc, cloned->getResult(0));
    rewriter.setInsertionPointAfter(custom);
    rewriter.replaceOp(red, custom.getResult());
    return success();
  }
};
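// Sparse rewriting rule for the print operator: expands sparse_tensor.print
// into vector.print ops that dump the number of stored entries, the dim/lvl
// sizes, and the contents of each storage field (positions, coordinates,
// values).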
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(PrintOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto tensor = op.getTensor();
    auto stt = getSparseTensorType(tensor);
    // Header with the number of stored entries.
    auto nse = NumberOfEntriesOp::create(rewriter, loc, tensor);
    vector::PrintOp::create(
        rewriter, loc,
        rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
    vector::PrintOp::create(rewriter, loc, nse);
    // Print the run-time dimension and level sizes.
    vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
    // Use the "codegen" foreach construct to iterate over all storage fields
    // of the sparse tensor and print each one.
    foreachFieldAndTypeInSparseTensor(stt, [&](Type, FieldIndex,
                                               SparseTensorFieldKind kind,
                                               Level l, LevelType) {
      switch (kind) {
      case SparseTensorFieldKind::StorageSpec: {
        break;
      }
      case SparseTensorFieldKind::PosMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("pos["));
        vector::PrintOp::create(rewriter, loc, lvl,
                                vector::PrintPunctuation::NoPunctuation);
        vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
        auto pos = ToPositionsOp::create(rewriter, loc, tensor, l);
        printContents(rewriter, loc, pos);
        break;
      }
      case SparseTensorFieldKind::CrdMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("crd["));
        vector::PrintOp::create(rewriter, loc, lvl,
                                vector::PrintPunctuation::NoPunctuation);
        vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : "));
        Value crd = nullptr;
        // For AoS COO storage, print a single linear view of the full
        // coordinate storage at the start level; for any other storage,
        // print the coordinate storage of each individual level.
        if (stt.getAoSCOOStart() == l)
          crd = ToCoordinatesBufferOp::create(rewriter, loc, tensor);
        else
          crd = ToCoordinatesOp::create(rewriter, loc, tensor, l);
        printContents(rewriter, loc, crd);
        break;
      }
      case SparseTensorFieldKind::ValMemRef: {
        vector::PrintOp::create(rewriter, loc,
                                rewriter.getStringAttr("values : "));
        auto val = ToValuesOp::create(rewriter, loc, tensor);
        printContents(rewriter, loc, val);
        break;
      }
      }
      return true;
    });
    vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
    return success();
  }
  static void printContents(PatternRewriter &rewriter, Location loc,
                            Value vec) {
    auto shape = cast<ShapedType>(vec.getType()).getShape();
    SmallVector<Value> idxs;
    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
  }
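  // Recursive helper that prints one level of the memref as ( a0, a1, ... ),
  // emitting a separating comma after every element except the last.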
  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
                                 SmallVectorImpl<Value> &idxs) {
    // Open bracket.
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
    // Generate the loop over this level.
    auto zero = constantIndex(rewriter, loc, 0);
    auto index = constantIndex(rewriter, loc, i);
    auto size = memref::DimOp::create(rewriter, loc, vec, index);
    auto step = constantIndex(rewriter, loc, 1);
    auto forOp = scf::ForOp::create(rewriter, loc, zero, size, step);
    idxs.push_back(forOp.getInductionVar());
    rewriter.setInsertionPointToStart(forOp.getBody());
    if (i < shape.size() - 1) {
      // Enter a deeper loop nest.
      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
    } else {
      // Actual contents printing.
      auto val = memref::LoadOp::create(rewriter, loc, vec, idxs);
      if (llvm::isa<ComplexType>(val.getType())) {
        // Since the vector dialect does not support complex types in its
        // print op, print the real and imaginary parts separately.
        Value real = complex::ReOp::create(rewriter, loc, val);
        Value imag = complex::ImOp::create(rewriter, loc, val);
        vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
        vector::PrintOp::create(rewriter, loc, real,
                                vector::PrintPunctuation::Comma);
        vector::PrintOp::create(rewriter, loc, imag,
                                vector::PrintPunctuation::Close);
      } else {
        vector::PrintOp::create(rewriter, loc, val,
                                vector::PrintPunctuation::NoPunctuation);
      }
    }
    // Terminating comma (except after the last element).
    auto bound = arith::AddIOp::create(rewriter, loc, idxs.back(), step);
    Value cond = arith::CmpIOp::create(rewriter, loc,
                                       arith::CmpIPredicate::ne, bound, size);
    scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, cond, /*else=*/false);
    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Comma);
    idxs.pop_back();
    rewriter.setInsertionPointAfter(forOp);
    // Close bracket.
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
  }
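  // Helper to print the run-time dim/lvl sizes as ( s0, s1, ... ); the loop
  // is unrolled since tensor.dim requires a constant dimension index.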
  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
                         unsigned size, bool isDim) {
    // Open bracket.
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
    // Print the unrolled contents.
    for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
      Value val;
      if (isDim)
        val = tensor::DimOp::create(rewriter, loc, tensor, idx);
      else
        val = LvlOp::create(rewriter, loc, tensor, idx);
      vector::PrintOp::create(rewriter, loc, val,
                              i != size - 1
                                  ? vector::PrintPunctuation::Comma
                                  : vector::PrintPunctuation::NoPunctuation);
    }
    // Close bracket and end of line.
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
  }
};
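// Sparse rewriting rule for the generic tensor.reshape operator: rewrites
// the reshape as a foreach over the source that collapses each source
// coordinate into a single linearized coordinate and then expands it into
// the (static) destination shape, inserting the value into a freshly
// allocated buffer.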
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSource();
    const auto srcTp = tryGetSparseTensorType(srcTensor);
    const auto dstTp = tryGetSparseTensorType(op.getResult());
    if (!srcTp || !dstTp)
      return failure();
    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
        !dstTp->hasStaticDimShape())
      return failure();

    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
    Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
    // Only need a temporary unordered COO buffer if input and output are
    // not sorted in the same way.
    Type bufferTp = getBufferType(
        dstTp->withoutDimToLvl(),
        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
    SmallVector<Value> dynSizes;
    Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes,
                                         Value(), nnz, Attribute())
                       .getResult();

    // Convert the reshape in two steps: the source coordinates are first
    // collapsed into a single linearized coordinate, which is then expanded
    // into the destination coordinates.
    const auto encSrc = srcTp->getEncoding();
    ForeachOp foreachOp = ForeachOp::create(
        rewriter, loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          // Map the level coordinates to dimension coordinates.
          const Dimension srcRank = srcTp->getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(srcRank);
          for (Dimension d = 0; d < srcRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }

          // Collapse the source coordinates to one linear coordinate.
          Value collapseSize = constantIndex(builder, loc, 1);
          for (Dimension d = 0; d < srcRank; d++)
            collapseSize =
                arith::MulIOp::create(builder, loc, collapseSize, srcSizes[d]);
          SmallVector<Value, 1> collapsedSizes = {collapseSize};
          ReassociationIndices collapseIdx;
          for (Dimension i = 0; i < srcRank; i++)
            collapseIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
          SmallVector<Value, 1> collapsedDcvs;
          reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
                     collapsedSizes, collapsedDcvs);

          // Expand the linear coordinate into destination coordinates.
          ReassociationIndices expandIdx;
          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
            expandIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
          SmallVector<Value> dstSizes;
          for (Size sz : dstTp->getDimShape())
            dstSizes.push_back(constantIndex(builder, loc, sz));
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                     dstSizes, dstDcvs);

          auto t =
              tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
          sparse_tensor::YieldOp::create(builder, loc, t);
        });

    Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0),
                             /*hasInserts=*/true);
    if (bufferTp != *dstTp) {
      auto dstRTT = dstTp->getRankedTensorType();
      Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
      DeallocTensorOp::create(rewriter, loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};
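// Sparse rewriting rule for sparse-to-sparse tensor.expand_shape and
// tensor.collapse_shape: a foreach over the source remaps each coordinate
// through the reassociation indices and inserts the value into a new buffer.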
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
    const auto dstTp = getSparseTensorType(op.getResult());
    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
      return failure();

    // Generate the output sizes, collecting the dynamic ones.
    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    SmallVector<Value> dstDynSizes;
    if (dstTp.hasStaticDimShape()) {
      for (auto d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
    } else {
      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes,
                         dstTp.getDimShape(), op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstTp.getDimShape())) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
      }
    }
    Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
    // Only need a temporary unordered COO buffer if input and output are
    // not sorted in the same way.
    Type bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
    Value buffer =
        AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(),
                              /*sizeHint=*/nnz, Attribute())
            .getResult();

    // Implement the reshape as a foreach over the source that remaps the
    // coordinates through the reassociation and inserts into the buffer.
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = ForeachOp::create(
        rewriter, loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          // Map the level coordinates to dimension coordinates.
          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(dimRank);
          for (Dimension d = 0; d < dimRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                     srcDcvs, dstSizes, dstDcvs);
          auto t =
              tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
          sparse_tensor::YieldOp::create(builder, loc, t);
        });

    Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0),
                             /*hasInserts=*/true);
    if (bufferTp != dstTp) {
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
      DeallocTensorOp::create(rewriter, loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};
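// Sparse rewriting rule for mixed sparse/dense reshape: a sparse source is
// first converted to dense and then reshaped, while a sparse destination is
// produced by reshaping in the dense domain and converting the result back
// to sparse. Pure sparse-to-sparse cases are handled elsewhere.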
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    if (encDst && encSrc) {
      return failure();
    }
    if (encSrc) {
      // Sparse source: unfuse a conversion to dense before the reshape.
      auto rtp = getRankedTensorType(op.getSrc());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc());
      rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
      return success();
    }
    if (encDst) {
      // Sparse destination: reshape in the dense domain, then convert.
      auto rtp = getRankedTensorType(op.getResult());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      ReshapeOp reshape;
      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
        reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
                                    op.getReassociation(), op.getOutputShape(),
                                    op.getStaticOutputShape());
      } else {
        reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
                                    op.getReassociation());
      }
      Value convert = ConvertOp::create(rewriter, loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};
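// A trivial wrapper around the destination buffer of a rewritten op, to help
// generate uniform code for dense and sparse destinations: it materializes
// (and zero-fills, when dense) the buffer, accepts coordinate insertions,
// and finalizes to the requested tensor type.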
struct TensorLike {
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynSzs;
    getDynamicSizes(rtt, sizes, dynSzs);
    val = AllocTensorOp::create(builder, loc, rtt, dynSzs);
    if (!isSparse()) {
      Value c0 = constantZero(builder, loc, rtt.getElementType());
      val = linalg::FillOp::create(builder, loc, c0, val).getResult(0);
    }
  }

  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    val = tensor::InsertOp::create(builder, loc, v, val, crds);
  }

  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) {
    return LoadOp::create(builder, loc, val, /*hasInserts=*/true);
  }

  bool isSparse() const {
    return getSparseTensorEncoding(val.getType()) != nullptr;
  }

  Value val;
};
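// Sparse rewriting rule for tensor.dim on a sparse tensor: for permutation
// maps the dim query becomes a level query; otherwise the dimension size is
// reconstructed by translating the maximal level coordinates back to
// dimension space through the lvlToDim map.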
struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> dim = op.getConstantIndex();
    auto stt = tryGetSparseTensorType(op.getSource());
    if (!dim || !stt || !stt->hasEncoding())
      return failure();

    if (stt->isPermutation()) {
      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
                                         toLvl(stt->getEncoding(), *dim));
      return success();
    }

    // For a non-permutation map, the dimension size cannot be queried as a
    // level size directly. Instead, compute it as
    //   affine.apply lvlToDim(l0 - 1, l1 - 1, ...) + 1
    // i.e., translate the maximal level coordinates back to dim space.
    Location loc = op.getLoc();
    SmallVector<Value> maxLvlCrds;
    for (Level l = 0; l < stt->getLvlRank(); l++) {
      Value lvlSz = LvlOp::create(rewriter, loc, op.getSource(), l);
      Value maxLvlCrd = arith::SubIOp::create(
          rewriter, loc, lvlSz, constantOne(rewriter, loc, lvlSz.getType()));
      maxLvlCrds.push_back(maxLvlCrd);
    }

    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
    Value maxDimCrd = affine::AffineApplyOp::create(
        rewriter, op.getLoc(),
        AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp), maxLvlCrds);

    Value dimSz = arith::AddIOp::create(rewriter, loc, maxDimCrd,
                                        constantIndex(rewriter, loc, 1));
    rewriter.replaceOp(op, dimSz);
    return success();
  }
};
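// Sparse rewriting rule for the concatenate operator: iterates over each
// input with a foreach loop, shifting the coordinate along the concatenation
// dimension by a running offset and inserting the values into a shared
// destination buffer.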
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern<ConcatenateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      op.emitError("ConcatenateOp not staged");

    const Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op);
    const Dimension conDim = op.getDimension();
    SmallVector<Value> sizes;
    concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);

    // foreach in %s1 : insert d0, d1, %tmp
    // foreach in %s2 : insert d0, d1 + size(s1), %tmp
    // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
    Value offset = constantIndex(rewriter, loc, 0);
    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    Value iterArg = dstBuf.val;

    ForeachOp foreachOp;
    for (Value input : op.getInputs()) {
      // Build a foreach op for each input tensor to append new values into
      // the output tensor.
      foreachOp = ForeachOp::create(
          rewriter, loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            SmallVector<Value> offDimCrd(dcvs);
            offDimCrd[conDim] =
                arith::AddIOp::create(builder, loc, offDimCrd[conDim], offset);

            // Enter an if-else to insert only when the value is non-zero.
            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              Value cond = genIsNonzero(builder, loc, v);
              auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(),
                                            cond, /*else=*/true);
              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
              scf::YieldOp::create(builder, loc, dstBuf.val);

              builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
              dstBuf.insert(builder, loc, v, offDimCrd);
              scf::YieldOp::create(builder, loc, dstBuf.val);

              // Exit the ifOp and update the sparse tensor SSA value.
              builder.setInsertionPointAfter(ifOp);
              dstBuf.val = ifOp.getResult(0);
            } else {
              dstBuf.insert(builder, loc, v, offDimCrd);
            }
            sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
          });

      // Accumulate the offset for the next input.
      const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
      assert(ShapedType::isStatic(sz));
      offset = arith::AddIOp::create(rewriter, loc, offset,
                                     constantIndex(rewriter, loc, sz));
      iterArg = foreachOp.getResult(0);
      dstBuf.val = iterArg;
    }

    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};
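// Sparse rewriting rule for the convert operator: when no extra sort is
// needed, performs the conversion directly with a foreach loop that inserts
// every (nonzero) element of the source into a destination buffer.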
struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");

    // Conversions that only change the bit widths of the positions and
    // coordinates are handled by codegen, not here.
    auto encDst = getSparseTensorEncoding(op.getDest().getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (encDst && encSrc && !encSrc.isSlice() &&
        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      return failure();
    }

    Location loc = op.getLoc();
    Value src = op.getSource();
    SparseTensorType srcStt = getSparseTensorType(op.getSource());
    SparseTensorType dstStt = getSparseTensorType(op.getDest());

    bool fromSparseConst = false;
    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
      if (isa<SparseElementsAttr>(constOp.getValue()))
        fromSparseConst = true;

    const AffineMapAttr foreachOrder =
        (!dstStt.isIdentity() && fromSparseConst)
            ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
            : nullptr;

    // The zero check can be skipped when the source is sparse (only
    // nonzeros are stored) or a sparse constant (only nonzeros enumerated).
    bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;

    SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);
    TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

    auto foreachOp = ForeachOp::create(
        rewriter, loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Enter an if-else to insert only when the value is non-zero.
          dstBuf.val = reduc.front();
          if (!skipZeroCheck) {
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
                                          /*else=*/true);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
            scf::YieldOp::create(builder, loc, dstBuf.val);

            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
            dstBuf.insert(builder, loc, v, dcvs);
            scf::YieldOp::create(builder, loc, dstBuf.val);

            // Exit the ifOp and update the sparse tensor SSA value.
            builder.setInsertionPointAfter(ifOp);
            dstBuf.val = ifOp.getResult(0);
          } else {
            dstBuf.insert(builder, loc, v, dcvs);
          }
          sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
        });

    rewriter.setInsertionPointAfter(foreachOp);
    dstBuf.val = foreachOp.getResult(0);
    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};
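// Sparse rewriting rule for the crd_translate operator: unfolds the
// dim-to-lvl (or lvl-to-dim) map into one affine.apply per result.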
struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CrdTranslateOp op,
                                PatternRewriter &rewriter) const override {
    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                        ? op.getEncoder().getDimToLvl()
                        : op.getEncoder().getLvlToDim();
    SmallVector<Value> outCrds;
    for (AffineExpr result : map.getResults()) {
      // Materialize each result of the map as a single-result affine.apply.
      Value trans = affine::AffineApplyOp::create(
          rewriter, op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
          op.getInCrds());
      outCrds.push_back(trans);
    }
    rewriter.replaceOp(op, outCrds);
    return success();
  }
};
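// Sparse rewriting rule for the foreach operator: lowers the foreach into an
// actual loop nest over the stored elements (via the loop emitter), or
// unrolls it completely for a sparse constant input.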
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ForeachOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value input = op.getTensor();
    SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
    const Level lvlRank = stt.getLvlRank();

    // Special-case: a foreach over a sparse constant unrolls completely.
    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
        return genForeachOnSparseConstant(op, rewriter, attr);
      }
    }

    // Otherwise, use the loop emitter to generate a loop nest, one loop
    // sequence per level.
    const auto enc = stt.getEncoding();
    LoopEmitter loopEmitter(
        ValueRange{input},
        StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
    for (Level l = 0; l < lvlRank; l++) {
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // The reduction values are updated in place by the loop emitter.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                    reduc);
    }

    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
    if (op.getOrder()) {
      llvm_unreachable(
          "Level order not yet implemented on non-constant input tensors.");
    }

    // Load the value: by position for sparse storage, by coordinates for
    // dense storage.
    Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
    Value val = enc ? memref::LoadOp::create(rewriter, loc, vals, pos)
                    : memref::LoadOp::create(rewriter, loc, vals, lcvs);

    // Inline the block of the foreach operator, remapping level coordinates
    // to dimension coordinates, followed by the value and the reductions.
    Block *srcBlock = op.getBody();
    SmallVector<Value> args =
        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
    args.push_back(val);
    args.append(reduc.begin(), reduc.end());

    // Remove the sparse_tensor.yield and remember the reduction values.
    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    rewriter.eraseOp(srcBlock->getTerminator());

    Operation &last = rewriter.getBlock()->back();
    if (llvm::isa<scf::YieldOp>(last)) {
      // An scf.for inserts an implicit yield op when there is no reduction
      // variable upon creation; inline the block before that yield.
      rewriter.setInsertionPoint(&last);
    }
    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                               rewriter.getInsertionPoint(), args);
    rewriter.setInsertionPointToEnd(rewriter.getBlock());

    // Link the reduction chain while exiting the loop nest.
    for (Level l = 0; l < lvlRank; l++) {
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
    }

    // Replace the foreach operator with the value returned by the outermost
    // for loop.
    rewriter.replaceOp(op, reducValue);
    return success();
  }
};
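// Sparse rewriting rule for the new operator: reads a non-COO tensor from
// file by first materializing an ordered COO tensor and then converting it
// to the destination format, deallocating the temporary COO afterwards.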
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      return failure();

    // Implement the NewOp as follows:
    //   %orderedCoo = sparse_tensor.new %filename
    //   %t = sparse_tensor.convert %orderedCoo
    // with enveloping reinterpret_map ops for non-permutation maps.
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = NewOp::create(rewriter, loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) { // demap the temporary COO
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = ReinterpretMapOp::create(rewriter, loc, coo, convert);
      dstTp = getSparseTensorType(convert)
                  .withEncoding(enc.withoutDimToLvl())
                  .getRankedTensorType();
    }
    convert = ConvertOp::create(rewriter, loc, dstTp, convert);
    if (!stt.isPermutation()) // remap to the original encoding
      convert = ReinterpretMapOp::create(rewriter, loc, enc, convert);
    rewriter.replaceOp(op, convert);

    // Release the temporary ordered COO tensor.
    rewriter.setInsertionPointAfterValue(convert);
    DeallocTensorOp::create(rewriter, loc, cooTensor);
    return success();
  }
};
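// Sparse rewriting rule for the out operator: writes the tensor to a file
// through the runtime sparse-tensor-writer API, streaming the metadata first
// and then one (coordinates, value) pair per stored element.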
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Calculate the number of nonzero entries.
    Value src = op.getTensor();
    Value nnz = NumberOfEntriesOp::create(rewriter, loc, src);

    // Allocate a temporary buffer for storing dimension sizes/coordinates.
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);

    // Generate code to compute the dimension sizes and store them.
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      memref::StoreOp::create(rewriter, loc, dims[d], dimSizes,
                              constantIndex(rewriter, loc, d));
    }

    // Create a sparse tensor writer and output the metadata.
    Type opaqueTp = getOpaquePointerType(rewriter.getContext());
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    Value dimCoords = dimSizes; // Reuse the dimSizes buffer for coordinates.
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // For each element in the source tensor, output the element.
    ForeachOp::create(
        rewriter, loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          for (Dimension d = 0; d < dimRank; d++) {
            memref::StoreOp::create(rewriter, loc, dcvs[d], dimCoords,
                                    constantIndex(builder, loc, d));
          }
          memref::StoreOp::create(rewriter, loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          func::CallOp::create(builder, loc, TypeRange(), fn, operands);
          sparse_tensor::YieldOp::create(builder, loc);
        });

    // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};
//===---------------------------------------------------------------------===//
// Methods that add the patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      patterns.getContext());
  if (enableConvert)
    patterns.add<DirectConvertRewriter>(patterns.getContext());
  if (!enableRT)
    patterns.add<NewRewriter>(patterns.getContext());
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}