1 //===- StageSparseOperations.cpp - stage sparse ops rewriting rules -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
15 using namespace mlir;
16 using namespace mlir::sparse_tensor;
18 namespace {
20 struct GuardSparseAlloc
21  : public OpRewritePattern<bufferization::AllocTensorOp> {
24  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
25  PatternRewriter &rewriter) const override {
26  // Only rewrite sparse allocations.
28  return failure();
30  // Only rewrite sparse allocations that escape the method
31  // without any chance of a finalizing operation in between.
32  // Here we assume that sparse tensor setup never crosses
33  // method boundaries. The current rewriting only repairs
34  // the most obvious allocate-call/return cases.
35  if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
36  return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
37  use.getOwner());
38  }))
39  return failure();
41  // Guard escaping empty sparse tensor allocations with a finalizing
42  // operation that leaves the underlying storage in a proper state
43  // before the tensor escapes across the method boundary.
44  rewriter.setInsertionPointAfter(op);
45  auto load = rewriter.create<LoadOp>(op.getLoc(), op.getResult(), true);
46  rewriter.replaceAllUsesExcept(op, load, load);
47  return success();
48  }
49 };
51 template <typename StageWithSortOp>
52 struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
55  LogicalResult matchAndRewrite(StageWithSortOp op,
56  PatternRewriter &rewriter) const override {
57  Location loc = op.getLoc();
58  Value tmpBuf = nullptr;
59  auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
60  LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
61  // Deallocate tmpBuf.
62  // TODO: Delegate to buffer deallocation pass in the future.
63  if (succeeded(stageResult) && tmpBuf)
64  rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
66  return stageResult;
67  }
68 };
69 } // namespace
72  patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
73  StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
74 }
