MLIR  22.0.0git
StageSparseOperations.cpp
Go to the documentation of this file.
1 //===- StageSparseOperations.cpp - stage sparse ops rewriting rules -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 
14 using namespace mlir;
15 using namespace mlir::sparse_tensor;
16 
17 namespace {
18 
19 struct GuardSparseAlloc
20  : public OpRewritePattern<bufferization::AllocTensorOp> {
22 
23  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
24  PatternRewriter &rewriter) const override {
25  // Only rewrite sparse allocations.
26  if (!getSparseTensorEncoding(op.getResult().getType()))
27  return failure();
28 
29  // Only rewrite sparse allocations that escape the method
30  // without any chance of a finalizing operation in between.
31  // Here we assume that sparse tensor setup never crosses
32  // method boundaries. The current rewriting only repairs
33  // the most obvious allocate-call/return cases.
34  if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
35  return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
36  use.getOwner());
37  }))
38  return failure();
39 
40  // Guard escaping empty sparse tensor allocations with a finalizing
41  // operation that leaves the underlying storage in a proper state
42  // before the tensor escapes across the method boundary.
43  rewriter.setInsertionPointAfter(op);
44  auto load = LoadOp::create(rewriter, op.getLoc(), op.getResult(), true);
45  rewriter.replaceAllUsesExcept(op, load, load);
46  return success();
47  }
48 };
49 
50 template <typename StageWithSortOp>
51 struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
53 
54  LogicalResult matchAndRewrite(StageWithSortOp op,
55  PatternRewriter &rewriter) const override {
56  Location loc = op.getLoc();
57  Value tmpBuf = nullptr;
58  auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
59  LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
60  // Deallocate tmpBuf.
61  // TODO: Delegate to buffer deallocation pass in the future.
62  if (succeeded(stageResult) && tmpBuf)
63  bufferization::DeallocTensorOp::create(rewriter, loc, tmpBuf);
64 
65  return stageResult;
66  }
67 };
68 } // namespace
69 
71  patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
72  StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
73 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
This class represents an operand of an operation.
Definition: Value.h:257
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:783
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of `from` and replace them with `to`, except where the user is `exceptedUser`.
Definition: PatternMatch.h:700
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314