// In isSparseTensor(Value v):
  return enc && !llvm::all_of(enc.getLvlTypes(),
                              [](auto lt) { return lt == LevelFormat::Dense; });
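// Note: a tensor whose encoding has only dense levels is deliberately not
// treated as sparse by this predicate, so all-dense annotated tensors do not
// trigger the sparse rewriting patterns below.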
// In isSampling(GenericOp op):
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
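// For illustration (a made-up example, not from this file), a linalg.generic
// body of the following shape is accepted by isSampling():
//
//   ^bb0(%s1: f64, %s2: f64, %out: f64):
//     %0 = arith.mulf %s1, %s2 : f64
//     linalg.yield %0 : f64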
// In isMulChain(Value val, Value x):
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  // ...
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
// In isSumOfMul(GenericOp op):
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
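// For illustration (again made up), a reduction body like the following is
// accepted, since the yielded value adds a chain of multiplications into the
// output argument %x:
//
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %0 = arith.mulf %a, %b : f64
//     %1 = arith.addf %0, %x : f64
//     linalg.yield %1 : f64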
// In isZeroYield(GenericOp op):
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
    if (arg.getOwner()->getParentOp() == op) {
// In sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes, ...):
  for (const auto &d : enumerate(stp.getShape())) {
    // ...
    if (d.value() == ShapedType::kDynamic)
      dim = tensor::DimOp::create(builder, loc, tensor, d.index());
    // ...
    sizes.push_back(dim);
// In getDynamicSizes(RankedTensorType tp, ValueRange sizes,
//                    SmallVectorImpl<Value> &dynSizes):
  for (const auto &d : enumerate(tp.getShape())) {
    if (d.value() == ShapedType::kDynamic)
      dynSizes.push_back(sizes[d.index()]);
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) {
  auto loc = op.getLoc();
  // ...
  foreachInSparseConstant(
      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
      [&](ArrayRef<Value> cvs, Value v) {
        // ...
        args.append(cvs.begin(), cvs.end());
        // ...
        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
        assert(args.size() == cloned.getBody()->getNumArguments());
        Operation *yield = cloned.getBody()->getTerminator();
// In concatSizesFromInputs(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
//                          Location loc, ShapedType dstTp, ValueRange srcs,
//                          unsigned dim):
  auto dstShape = dstTp.getShape();
  // ...
  if (dstShape[dim] != ShapedType::kDynamic) {
    // ...
  for (const auto &src : srcs.drop_front()) {
    // ...
    sizes[dim] = arith::AddIOp::create(builder, loc, sizes[dim], srcSz);
struct FuseExtractSliceWithConcat
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
    // ...
    Location loc = extractOp.getLoc();
    int64_t dim = concatOp.getDim();
    int64_t rank = extractOp.getResultType().getRank();
    // ...
    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
    // ...
    SmallVector<AffineExpr> partialSums = {sum};
    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
    for (auto [idx, input] :
         llvm::enumerate(concatOp.getInputs().drop_back())) {
      // ...
      partialSums.push_back(sum);
      offsetStrides.push_back(/*...*/);
    }
    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
                                        /*...*/);
    SmallVector<OpFoldResult> dimOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, partialSumMap, offsetStrides);

    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
      for (auto [l, r] : llvm::zip(lhs, rhs)) {
        // ...
      }
      return lhs.size() == rhs.size();
    };

    for (auto [i, input, offset] :
         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
      SmallVector<OpFoldResult> srcSizes =
          tensor::getMixedSizes(rewriter, loc, input);
      srcOffsets[dim] = offset;

      SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
      SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
      SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();

      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
          allEqual(srcStrides, dstStrides)) {
        Value operand = concatOp.getOperand(i);
        if (operand.getType() == extractOp.getResultType())
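// For illustration, when the slice lines up exactly with one concatenated
// input (a made-up example, not from this file):
//
//   %c = tensor.concat dim(0) %a, %b : (tensor<4x4xf64>, tensor<4x4xf64>)
//        -> tensor<8x4xf64>
//   %s = tensor.extract_slice %c[4, 0] [4, 4] [1, 1]
//        : tensor<8x4xf64> to tensor<4x4xf64>
//
// the extract_slice %s is replaced directly by %b.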
// FoldConvertIntoProducer:
  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getSource().getDefiningOp<GenericOp>();
    if (!producer || producer.getDpsInits().size() != 1 ||
        // ...
        !producer.getResult(0).hasOneUse()) {
      // ...
    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
    Operation *cloned = rewriter.clone(*init);
    // ...
    producer.getDpsInitsMutable().assign(cloned->getResults());
    producer.getResult(0).setType(op.getResult().getType());
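// In effect, a sparse_tensor.convert that merely retypes the result of a
// single-use linalg.generic is folded away: the generic's materializing init
// is cloned with the converted type and the generic's result is retyped, so
// the producer writes out the desired format directly.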
// FoldInvariantYield:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
        // ...
        !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
      // ...
      rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
    // ...
    if (!outputType.hasStaticShape())
      // ...
    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
// FuseSparseMultiplyOverAdd:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
      // ...
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getDpsInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      // ...
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputs();
    SmallVector<Value> outputOps = op.getOutputs();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back());
    // ...
    auto fusedOp = GenericOp::create(
        rewriter, loc, op.getResult(0).getType(), inputOps, outputOps,
        /*...*/);
    // ...
    Block &prodBlock = prod.getRegion().front();
    Block &consBlock = op.getRegion().front();
    // ...
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // ...
      last = op.getResult(0);
    rewriter.clone(op, mapper);
    // ...
    linalg::YieldOp::create(rewriter, loc, last);
    // ...
      Value init = prod.getDpsInitOperand(0)
                       ->get()
                       .getDefiningOp<AllocTensorOp>()
                       .getCopy();
      AllocTensorOp a =
          op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
      rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
    // ...
    rewriter.replaceOp(op, fusedOp->getResults());
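// For intuition (schematic, not from this file): with all-identity maps,
// a producer multiply feeding a consumer add such as
//
//   T(i,j) = A(i,j) * B(i,j);   X(i,j) = T(i,j) + C(i,j)
//
// is fused into a single generic X(i,j) = A(i,j) * B(i,j) + C(i,j), so the
// multiplication stays guarded by the sparsity of its operands.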
  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
// FuseTensorCast:
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    Type srcType = op.getSource().getType();
    Type dstType = op.getDest().getType();
    // ...
    if (srcType == dstType) {
      rewriter.replaceOp(op, op->getResults());
      // ...
    if (Operation *def = op.getSource().getDefiningOp()) {
      if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
// GenSemiRingSelect:
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    Location loc = op.getLoc();
    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
    for (Operation &inst : *op.getBody()) {
      // ...
      auto matched = isRewritablePattern(op, &inst);
      if (!matched.has_value())
        // ...
      auto [c, t, f] = matched.value();
      assert(t.getType() == f.getType());
      auto selTp = t.getType();
      // ...
      auto binOp = sparse_tensor::BinaryOp::create(rewriter, loc, selTp, t, f);
      // ...
      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
                           {t.getLoc(), f.getLoc()});
      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
      // ...
      for (auto *r : binOp.getRegions()) {
        // ...
        if (auto *def = c.getDefiningOp())
          // ...
        if (r == &binOp.getLeftRegion()) {
          irMap.map(t, b->getArgument(0));
          // ...
        } else if (r == &binOp.getRightRegion()) {
          // ...
          irMap.map(f, b->getArgument(0));
        } else {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, b->getArgument(1));
        }
        // ...
        sparse_tensor::YieldOp::create(rewriter, loc, y);
      }
      // ...
      semiRings.emplace_back(&inst, binOp);
    }
    // ...
    for (auto [sel, semi] : semiRings)
      rewriter.replaceOp(sel, semi->getResults());

    return success(!semiRings.empty());
  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
  isRewritablePattern(GenericOp op, Operation *v) {
    auto sel = dyn_cast<arith::SelectOp>(v);
    // ...
    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // ...
    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
        // ...
      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
    };
    // ...
    auto cond = sel.getCondition();
    if (isValFromDenseInputOrInvariant(cond))
      return std::make_tuple(cond, tVal, fVal);
    // ...
    if (isValFromDenseInputOrInvariant(cmpL) ||
        isValFromDenseInputOrInvariant(cmpR))
      return std::make_tuple(cond, tVal, fVal);
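// For illustration (schematic): a body computing
//
//   %cmp = arith.cmpf olt, %t, %f : f64
//   %sel = arith.select %cmp, %t, %f : f64
//
// over sparse inputs is rewritten so %sel becomes a sparse_tensor.binary
// whose overlap/left/right regions recompute the condition on the available
// operands, giving the select well-defined semiring semantics when one side
// is absent.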
// GenSemiRingReduction:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      // ...
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    // ...
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      // ...
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      // ...
    Location loc = op.getLoc();
    Value identity =
        tensor::ExtractOp::create(rewriter, loc, init->get(), ValueRange());
    // ...
    auto semiring = sparse_tensor::UnaryOp::create(rewriter, loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    // ...
    sparse_tensor::YieldOp::create(rewriter, loc, present->getArgument(0));
    // ...
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    // ...
    auto zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(rtp));
    sparse_tensor::YieldOp::create(rewriter, loc, zero);
    // ...
    auto custom = sparse_tensor::ReduceOp::create(
        rewriter, loc, rtp, semiring.getResult(), s1, identity);
    // ...
    rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    // ...
    auto *cloned = rewriter.clone(*red, irMap);
    sparse_tensor::YieldOp::create(rewriter, loc, cloned->getResult(0));
    // ...
    rewriter.replaceOp(red, custom.getResult());
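// For intuition (schematic): a reduction x = min/max/and/mul(x, a) is
// rewritten into roughly
//
//   %u = sparse_tensor.unary %a   // present -> %a, absent -> 0
//   %r = sparse_tensor.reduce %u, %x, %identity { original arith op }
//
// with %identity extracted from the init tensor, so the sparsifier handles
// reductions whose identity is not zero (e.g. product, min) correctly.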
// PrintRewriter:
  LogicalResult matchAndRewrite(PrintOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto tensor = op.getTensor();
    // ...
    auto nse = NumberOfEntriesOp::create(rewriter, loc, tensor);
    vector::PrintOp::create(
        /*...*/);
    vector::PrintOp::create(rewriter, loc, nse);
    // ...
    vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
    // ...
    foreachFieldAndTypeInSparseTensor(
        /*...*/
                  Level l, LevelType) {
          switch (kind) {
          case SparseTensorFieldKind::StorageSpec: {
            // ...
          case SparseTensorFieldKind::PosMemRef: {
            // ...
            vector::PrintOp::create(rewriter, loc,
                                    rewriter.getStringAttr("pos["));
            vector::PrintOp::create(rewriter, loc, lvl,
                                    vector::PrintPunctuation::NoPunctuation);
            vector::PrintOp::create(rewriter, loc,
                                    rewriter.getStringAttr("] : "));
            auto pos = ToPositionsOp::create(rewriter, loc, tensor, l);
            printContents(rewriter, loc, pos);
            // ...
          case SparseTensorFieldKind::CrdMemRef: {
            // ...
            vector::PrintOp::create(rewriter, loc,
                                    rewriter.getStringAttr("crd["));
            vector::PrintOp::create(rewriter, loc, lvl,
                                    vector::PrintPunctuation::NoPunctuation);
            vector::PrintOp::create(rewriter, loc,
                                    rewriter.getStringAttr("] : "));
            // ...
            if (stt.getAoSCOOStart() == l)
              crd = ToCoordinatesBufferOp::create(rewriter, loc, tensor);
            else
              crd = ToCoordinatesOp::create(rewriter, loc, tensor, l);
            printContents(rewriter, loc, crd);
            // ...
          case SparseTensorFieldKind::ValMemRef: {
            vector::PrintOp::create(rewriter, loc,
                                    /*...*/);
            auto val = ToValuesOp::create(rewriter, loc, tensor);
            printContents(rewriter, loc, val);
            // ...
    vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("----\n"));
  static void printContents(PatternRewriter &rewriter, Location loc,
                            Value vec) {
    auto shape = cast<ShapedType>(vec.getType()).getShape();
    SmallVector<Value> idxs;
    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
                                 SmallVectorImpl<Value> &idxs) {
    // ...
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
    // ...
    auto size = memref::DimOp::create(rewriter, loc, vec, index);
    // ...
    auto forOp = scf::ForOp::create(rewriter, loc, zero, size, step);
    idxs.push_back(forOp.getInductionVar());
    // ...
    if (i < shape.size() - 1) {
      // ...
      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
    } else {
      // ...
      auto val = memref::LoadOp::create(rewriter, loc, vec, idxs);
      if (llvm::isa<ComplexType>(val.getType())) {
        // ...
        Value real = complex::ReOp::create(rewriter, loc, val);
        Value imag = complex::ImOp::create(rewriter, loc, val);
        vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
        vector::PrintOp::create(rewriter, loc, real,
                                vector::PrintPunctuation::Comma);
        vector::PrintOp::create(rewriter, loc, imag,
                                vector::PrintPunctuation::Close);
      } else {
        vector::PrintOp::create(rewriter, loc, val,
                                vector::PrintPunctuation::NoPunctuation);
      }
    }
    // ...
    auto bound = arith::AddIOp::create(rewriter, loc, idxs.back(), step);
    Value cond = arith::CmpIOp::create(rewriter, loc,
                                       arith::CmpIPredicate::ne, bound, size);
    scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, cond,
                                       /*withElseRegion=*/false);
    // ...
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Comma);
    // ...
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
                         unsigned size, bool isDim) {
    // ...
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
    // ...
    for (unsigned i = 0; i < size; i++) {
      // ...
        val = tensor::DimOp::create(rewriter, loc, tensor, idx);
      // ...
        val = LvlOp::create(rewriter, loc, tensor, idx);
      vector::PrintOp::create(rewriter, loc, val,
                              // ...
                                  ? vector::PrintPunctuation::Comma
                                  : vector::PrintPunctuation::NoPunctuation);
    }
    // ...
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close);
    vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine);
// TensorReshapeRewriter:
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSource();
    // ...
    if (!srcTp || !dstTp)
      // ...
    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
        !dstTp->hasStaticDimShape())
      // ...
    SmallVector<Value> srcSizes;
    // ...
    SmallVector<Value> dstSizes;
    // ...
    Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
    // ...
    auto bufferTp = getBufferType(
        dstTp->withoutDimToLvl(),
        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
    SmallVector<Value> dynSizes;
    Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes,
                                         Value(), nnz, Attribute())
                       .getResult();
    // ...
    const auto encSrc = srcTp->getEncoding();
    ForeachOp foreachOp = ForeachOp::create(
        rewriter, loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension srcRank = srcTp->getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(srcRank);
          for (Dimension d = 0; d < srcRank; d++) {
            // ...
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          // ...
              arith::MulIOp::create(builder, loc, collapseSize, srcSizes[d]);
          SmallVector<Value, 1> collapsedSizes = {collapseSize};
          // ...
            collapseIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
          SmallVector<Value, 1> collapsedDcvs;
          reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
                     collapsedSizes, collapsedDcvs);
          // ...
          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
            expandIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                     dstSizes, dstDcvs);
          // ...
          tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
          sparse_tensor::YieldOp::create(builder, loc, t);
        });
    // ...
    Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0),
                             /*hasInserts=*/true);
    if (bufferTp != *dstTp) {
      auto dstRTT = dstTp->getRankedTensorType();
      Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
      DeallocTensorOp::create(rewriter, loc, t);
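// For intuition (made-up shapes): reshaping 4x6 -> 3x8 first collapses a
// source coordinate (d0, d1) into the linear index d0 * 6 + d1, then expands
// that index into destination coordinates (idx floordiv 8, idx mod 8), which
// is exactly the collapse/expand pair of reshapeCvs() calls built above.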
// Sparse2SparseReshapeRewriter:
template <typename ReshapeOp>
// ...
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSrc();
    // ...
    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
      // ...
    SmallVector<Value> srcSizes;
    // ...
    SmallVector<Value> dstSizes;
    SmallVector<Value> dstDynSizes;
    if (dstTp.hasStaticDimShape()) {
      // ...
      ArrayRef<Size> dstShape = dstTp.getDimShape();
      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                         op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
      }
    }
    // ...
    Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor);
    // ...
    auto bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
    Value buffer =
        AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(),
                              /*...*/);
    // ...
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = ForeachOp::create(
        rewriter, loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(dimRank);
          for (Dimension d = 0; d < dimRank; d++) {
            // ...
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                     srcDcvs, dstSizes, dstDcvs);
          // ...
          tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs);
          sparse_tensor::YieldOp::create(builder, loc, t);
        });
    // ...
    Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0),
                             /*hasInserts=*/true);
    if (bufferTp != dstTp) {
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult();
      DeallocTensorOp::create(rewriter, loc, t);
// ReshapeRewriter (unfuses a sparse conversion from a dense reshape):
template <typename ReshapeOp>
// ...
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // ...
    if (encDst && encSrc) {
      // ...
    }
    // ... (sparse source: convert to dense before the reshape)
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc());
    // ... (sparse destination: reshape as dense, then convert to sparse)
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      ReshapeOp reshape;
      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
        reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
                                    op.getReassociation(), op.getOutputShape(),
                                    op.getStaticOutputShape());
      } else {
        reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(),
                                    op.getReassociation());
      }
      Value convert = ConvertOp::create(rewriter, loc, rtp, reshape);
// TensorLike (helper wrapping a dense or sparse destination buffer):
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynSzs;
    // ...
    val = AllocTensorOp::create(builder, loc, rtt, dynSzs);
    // ...
      Value c0 = constantZero(builder, loc, rtt.getElementType());
      val = linalg::FillOp::create(builder, loc, c0, val).getResult(0);
    // ...

  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    val = tensor::InsertOp::create(builder, loc, v, val, crds);
  }

  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
    // ...
      return LoadOp::create(builder, loc, val, /*hasInserts=*/true);
    // ...

  bool isSparse() const {
// SparseTensorDimOpRewriter:
  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> dim = op.getConstantIndex();
    // ...
    if (!dim || !stt || !stt->hasEncoding())
      // ...
    if (stt->isPermutation()) {
      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
                                         toLvl(stt->getEncoding(), *dim));
      // ...
    Location loc = op.getLoc();
    SmallVector<Value> maxLvlCrds;
    for (Level l = 0; l < stt->getLvlRank(); l++) {
      Value lvlSz = LvlOp::create(rewriter, loc, op.getSource(), l);
      Value maxLvlCrd = arith::SubIOp::create(
          rewriter, loc, lvlSz,
          /*...*/);
      maxLvlCrds.push_back(maxLvlCrd);
    }
    // ...
    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
    Value maxDimCrd = affine::AffineApplyOp::create(
        rewriter, op.getLoc(),
        AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
        maxLvlCrds);
    // ...
    Value dimSz = arith::AddIOp::create(
        rewriter, loc, maxDimCrd,
        /*...*/);
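// In other words, for a non-permutation mapping the rewrite computes
//
//   dimSz = lvlToDim(lvlSz_0 - 1, ..., lvlSz_{n-1} - 1)[dim] + 1
//
// i.e. it applies the level-to-dimension map to the maximal in-bounds level
// coordinates and adds one to recover the dimension size.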
// ConcatenateRewriter:
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      op.emitError("ConcatenateOp not staged");
    // ...
    const Location loc = op.getLoc();
    // ...
    const Dimension conDim = op.getDimension();
    SmallVector<Value> sizes;
    // ...
    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    // ...
    Value iterArg = dstBuf.val;

    ForeachOp foreachOp;
    for (Value input : op.getInputs()) {
      // ...
      foreachOp = ForeachOp::create(
          rewriter, loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            SmallVector<Value> offDimCrd(dcvs);
            offDimCrd[conDim] =
                arith::AddIOp::create(builder, loc, offDimCrd[conDim], offset);
            // ...
            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              // ...
              auto ifOp =
                  scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
                                    /*...*/);
              // ...
              scf::YieldOp::create(builder, loc, dstBuf.val);
              // ...
              dstBuf.insert(builder, loc, v, offDimCrd);
              scf::YieldOp::create(builder, loc, dstBuf.val);
              // ...
              dstBuf.val = ifOp.getResult(0);
            } else {
              dstBuf.insert(builder, loc, v, offDimCrd);
            }
            sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
          });
      // ...
      assert(ShapedType::isStatic(sz));
      offset = arith::AddIOp::create(rewriter, loc, offset,
                                     /*...*/);
      iterArg = foreachOp.getResult(0);
      dstBuf.val = iterArg;
    }
    // ...
    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
// Direct ConvertOp lowering (element-wise copy via sparse_tensor.foreach):
  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");
    // ...
    if (encDst && encSrc && !encSrc.isSlice() &&
        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      // ...
    Location loc = op.getLoc();
    Value src = op.getSource();
    // ...
    bool fromSparseConst = false;
    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
      if (isa<SparseElementsAttr>(constOp.getValue()))
        fromSparseConst = true;
    // ...
    const AffineMapAttr foreachOrder =
        // ...
    bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
    // ...
    SmallVector<Value> sizes;
    // ...
    auto foreachOp = ForeachOp::create(
        rewriter, loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // ...
          dstBuf.val = reduc.front();
          if (!skipZeroCheck) {
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond,
                                          /*...*/);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
            scf::YieldOp::create(builder, loc, dstBuf.val);
            // ...
            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
            dstBuf.insert(builder, loc, v, dcvs);
            scf::YieldOp::create(builder, loc, dstBuf.val);
            // ...
            builder.setInsertionPointAfter(ifOp);
            dstBuf.val = ifOp.getResult(0);
          } else {
            dstBuf.insert(builder, loc, v, dcvs);
          }
          sparse_tensor::YieldOp::create(builder, loc, dstBuf.val);
        });
    // ...
    dstBuf.val = foreachOp.getResult(0);
// CrdTranslateRewriter:
  LogicalResult matchAndRewrite(CrdTranslateOp op,
                                PatternRewriter &rewriter) const override {
    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                        ? op.getEncoder().getDimToLvl()
                        : op.getEncoder().getLvlToDim();
    // ...
    SmallVector<Value> outCrds;
    // ...
      Value trans = affine::AffineApplyOp::create(
          /*...*/);
      outCrds.push_back(trans);
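// For illustration (a made-up 2x2 block encoding): with
// dimToLvl = (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2),
// a dim2lvl translation of (%i, %j) materializes one affine.apply per result
// of the map and collects the four level coordinates in outCrds.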
// ForeachRewriter (lowering sparse_tensor.foreach to SCF loops):
  LogicalResult matchAndRewrite(ForeachOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    auto loc = op.getLoc();
    Value input = op.getTensor();
    SmallVector<Value> reduc = op.getInitArgs();
    // ...
    const Level lvlRank = stt.getLvlRank();
    // ...
    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
        // ...
    const auto enc = stt.getEncoding();
    // ...
    LoopEmitter loopEmitter(
        // ...
        StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
    for (Level l = 0; l < lvlRank; l++) {
      // ...
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // ...
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                    /*...*/);
    }
    // ...
    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
    if (op.getOrder()) {
      // ...
      return rewriter.notifyMatchFailure(
          op, "Level order not yet implemented on non-constant input tensors.");
    }
    // ...
    Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
    // ...
    Value val = enc ? memref::LoadOp::create(rewriter, loc, vals, pos)
                    : memref::LoadOp::create(rewriter, loc, vals, lcvs);
    // ...
    Block *srcBlock = op.getBody();
    // ...
    SmallVector<Value> args =
        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
    // ...
    args.push_back(val);
    // ...
    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    // ...
    if (llvm::isa<scf::YieldOp>(last)) {
      // ...
    for (Level l = 0; l < lvlRank; l++) {
      // ...
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
    }
// NewOp lowering (read into an ordered COO tensor first, then convert):
  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // ...
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      // ...
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = NewOp::create(rewriter, loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) {
      // ...
      convert = ReinterpretMapOp::create(rewriter, loc, coo, convert);
    }
    convert = ConvertOp::create(rewriter, loc, dstTp, convert);
    if (!stt.isPermutation())
      convert = ReinterpretMapOp::create(rewriter, loc, enc, convert);
    // ...
    DeallocTensorOp::create(rewriter, loc, cooTensor);
// OutRewriter:
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // ...
    Value src = op.getTensor();
    Value nnz = NumberOfEntriesOp::create(rewriter, loc, src);
    // ...
    const Dimension dimRank = srcTp.getDimRank();
    // ...
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
    // ...
    SmallVector<Value> dims;
    // ...
    for (Dimension d = 0; d < dimRank; d++) {
      memref::StoreOp::create(rewriter, loc, dims[d], dimSizes,
                              /*...*/);
    }
    // ...
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    // ...
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
    // ...
    Value dimCoords = dimSizes;
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    // ...
    ModuleOp module = op->getParentOfType<ModuleOp>();
    // ...
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          for (Dimension d = 0; d < dimRank; d++) {
            memref::StoreOp::create(rewriter, loc, dcvs[d], dimCoords,
                                    /*...*/);
          }
          memref::StoreOp::create(rewriter, loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          func::CallOp::create(builder, loc, TypeRange(), fn, operands);
          sparse_tensor::YieldOp::create(builder, loc);
        });
    // ...
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
}
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      patterns.getContext());
  // ...
}
void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}
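// A minimal usage sketch (illustrative, not part of this file): a pass body
// that collects the pre-sparsification patterns and runs them to a fixed
// point with MLIR's greedy rewrite driver; the surrounding pass class and
// the driver entry-point name (applyPatternsGreedily in recent trees) are
// assumptions here.
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     populatePreSparsificationRewriting(patterns);
//     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
//   }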