19 struct GuardSparseAlloc
23 LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
34 if (!llvm::all_of(op->getUses(), [](
OpOperand &use) {
35 return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
44 auto load = LoadOp::create(rewriter, op.getLoc(), op.getResult(),
true);
50 template <
typename StageWithSortOp>
54 LogicalResult matchAndRewrite(StageWithSortOp op,
57 Value tmpBuf =
nullptr;
58 auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
59 LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
62 if (succeeded(stageResult) && tmpBuf)
63 bufferization::DeallocTensorOp::create(rewriter, loc, tmpBuf);
71 patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
72 StageUnorderedSparseOps<ConcatenateOp>>(
patterns.getContext());
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns)
Sets up StageSparseOperation rewriting rules.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.