MLIR 22.0.0git
StageSparseOperations.cpp
Go to the documentation of this file.
1//===- StageSparseOperations.cpp - stage sparse ops rewriting rules -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13
14using namespace mlir;
15using namespace mlir::sparse_tensor;
16
17namespace {
18
19struct GuardSparseAlloc
20 : public OpRewritePattern<bufferization::AllocTensorOp> {
21 using OpRewritePattern<bufferization::AllocTensorOp>::OpRewritePattern;
22
23 LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
24 PatternRewriter &rewriter) const override {
25 // Only rewrite sparse allocations.
26 if (!getSparseTensorEncoding(op.getResult().getType()))
27 return failure();
28
29 // Only rewrite sparse allocations that escape the method
30 // without any chance of a finalizing operation in between.
31 // Here we assume that sparse tensor setup never crosses
32 // method boundaries. The current rewriting only repairs
33 // the most obvious allocate-call/return cases.
34 if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
35 return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
36 use.getOwner());
37 }))
38 return failure();
39
40 // Guard escaping empty sparse tensor allocations with a finalizing
41 // operation that leaves the underlying storage in a proper state
42 // before the tensor escapes across the method boundary.
43 rewriter.setInsertionPointAfter(op);
44 auto load = LoadOp::create(rewriter, op.getLoc(), op.getResult(), true);
45 rewriter.replaceAllUsesExcept(op, load, load);
46 return success();
47 }
48};
49
50template <typename StageWithSortOp>
51struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
52 using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
53
54 LogicalResult matchAndRewrite(StageWithSortOp op,
55 PatternRewriter &rewriter) const override {
56 Location loc = op.getLoc();
57 Value tmpBuf = nullptr;
58 auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
59 LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
60 // Deallocate tmpBuf.
61 // TODO: Delegate to buffer deallocation pass in the future.
62 if (succeeded(stageResult) && tmpBuf)
63 bufferization::DeallocTensorOp::create(rewriter, loc, tmpBuf);
64
65 return stageResult;
66 }
67};
68} // namespace
69
71 patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
72 StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
73}
return success()
auto load
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:412
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of `from` and replace them with `to`, except if the user is `exceptedUser`.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.