19struct GuardSparseAlloc
21 using OpRewritePattern<bufferization::AllocTensorOp>::OpRewritePattern;
23 LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
24 PatternRewriter &rewriter)
const override {
34 if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
35 return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
44 auto load = LoadOp::create(rewriter, op.getLoc(), op.getResult(),
true);
50template <
typename StageWithSortOp>
52 using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
54 LogicalResult matchAndRewrite(StageWithSortOp op,
55 PatternRewriter &rewriter)
const override {
56 Location loc = op.getLoc();
57 Value tmpBuf =
nullptr;
58 auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
59 LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
62 if (succeeded(stageResult) && tmpBuf)
63 bufferization::DeallocTensorOp::create(rewriter, loc, tmpBuf);
71 patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
72 StageUnorderedSparseOps<ConcatenateOp>>(
patterns.getContext());
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...