MLIR  19.0.0git
StageSparseOperations.cpp
Go to the documentation of this file.
//===- StageSparseOperations.cpp - stage sparse ops rewriting rules -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

18 namespace {
19 
20 struct GuardSparseAlloc
21  : public OpRewritePattern<bufferization::AllocTensorOp> {
23 
24  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
25  PatternRewriter &rewriter) const override {
26  // Only rewrite sparse allocations.
28  return failure();
29 
30  // Only rewrite sparse allocations that escape the method
31  // without any chance of a finalizing operation in between.
32  // Here we assume that sparse tensor setup never crosses
33  // method boundaries. The current rewriting only repairs
34  // the most obvious allocate-call/return cases.
35  if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
36  return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
37  use.getOwner());
38  }))
39  return failure();
40 
41  // Guard escaping empty sparse tensor allocations with a finalizing
42  // operation that leaves the underlying storage in a proper state
43  // before the tensor escapes across the method boundary.
44  rewriter.setInsertionPointAfter(op);
45  auto load = rewriter.create<LoadOp>(op.getLoc(), op.getResult(), true);
46  rewriter.replaceAllUsesExcept(op, load, load);
47  return success();
48  }
49 };
50 
51 template <typename StageWithSortOp>
52 struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
54 
55  LogicalResult matchAndRewrite(StageWithSortOp op,
56  PatternRewriter &rewriter) const override {
57  Location loc = op.getLoc();
58  Value tmpBuf = nullptr;
59  auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
60  LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
61  // Deallocate tmpBuf.
62  // TODO: Delegate to buffer deallocation pass in the future.
63  if (succeeded(stageResult) && tmpBuf)
64  rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
65 
66  return stageResult;
67  }
68 };
69 } // namespace
70 
72  patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
73  StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
74 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:267
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:842
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358