struct GuardSparseAlloc
  // ...
  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
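    // ...
    // Only rewrite when every use of the allocation is a return or a call,
    // that is, when the tensor escapes the current function.
    if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
          return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner());
        }))
      return failure();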
    // ...
    // Guard the escaping allocation with a load that finalizes it.
    auto load = rewriter.create<LoadOp>(op.getLoc(), op.getResult(),
                                        /*hasInserts=*/true);
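To make the guard's logic concrete outside MLIR, here is a toy model in plain C++17. Everything in it is hypothetical: `ToyOp`, `Kind`, `usersOf`, and `guardEscapingAlloc` are illustrative names, and ops refer to each other by index rather than by SSA value. It also assumes, because those lines are elided from the listing, that the load is inserted right after the allocation and that the allocation's other uses are then redirected to it (in MLIR terms, `setInsertionPointAfter` followed by `replaceAllUsesExcept`).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical toy IR (not MLIR): an op is a kind plus operand indices that
// point at earlier ops in the same block.
enum class Kind { AllocTensor, Load, Return, Call, Other };

struct ToyOp {
  Kind kind;
  std::vector<int> operands;
};

// Returns the indices of all ops that use the result of op `def`.
std::vector<int> usersOf(const std::vector<ToyOp> &block, int def) {
  std::vector<int> users;
  for (int i = 0; i < static_cast<int>(block.size()); ++i) {
    const std::vector<int> &ops = block[i].operands;
    if (std::find(ops.begin(), ops.end(), def) != ops.end())
      users.push_back(i);
  }
  return users;
}

// Toy analogue of GuardSparseAlloc. Returns false if the pattern does not apply.
bool guardEscapingAlloc(std::vector<ToyOp> &block, int allocIdx) {
  assert(block[allocIdx].kind == Kind::AllocTensor);

  // Match: every user must be a return or a call (the llvm::all_of check).
  std::vector<int> users = usersOf(block, allocIdx);
  bool allEscape = std::all_of(users.begin(), users.end(), [&](int u) {
    return block[u].kind == Kind::Return || block[u].kind == Kind::Call;
  });
  if (!allEscape)
    return false;

  // Rewrite, step 1 (assumed): insert a load right after the alloc.
  // Indices at or past the insertion point shift by one, so fix them first.
  int loadIdx = allocIdx + 1;
  for (ToyOp &op : block)
    for (int &operand : op.operands)
      if (operand >= loadIdx)
        ++operand;
  block.insert(block.begin() + loadIdx, ToyOp{Kind::Load, {allocIdx}});

  // Rewrite, step 2 (assumed): every use of the alloc except the load's own
  // operand now reads the load, like replaceAllUsesExcept(alloc, load, load).
  for (int i = 0; i < static_cast<int>(block.size()); ++i) {
    if (i == loadIdx)
      continue;
    for (int &operand : block[i].operands)
      if (operand == allocIdx)
        operand = loadIdx;
  }
  return true;
}
```

In this model, `std::all_of` (like `llvm::all_of`) returns true for an empty range, so an allocation with no uses also gets a load. After one rewrite the allocation's only user is the load itself, which is neither a return nor a call, so the guard stops matching and the rewrite cannot loop.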
// ...
template <typename StageWithSortOp>
// ...
  LogicalResult matchAndRewrite(StageWithSortOp op,
    // ...
    Value tmpBuf = nullptr;
    auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
    LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
    // ...
    // Release the temporary buffer only if staging succeeded and created one.
    if (succeeded(stageResult) && tmpBuf)
      rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
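The control flow of `StageUnorderedSparseOps` can be sketched the same way. In this hypothetical model, `ToyRewriter` only records the names of the ops it creates, and `stageWithSort` stands in for the `StageWithSortSparseOp` interface method. The op names it records are placeholders, since the listing does not show what the real method emits. What the sketch does illustrate is the out-parameter contract: the caller starts with an empty `tmpBuf` and emits a deallocation only when staging succeeded and actually produced a temporary.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy rewriter: it only records the names of the ops it creates.
struct ToyRewriter {
  std::vector<std::string> created;

  // Records an op and returns its id (its position in `created`).
  int create(const std::string &name) {
    created.push_back(name);
    return static_cast<int>(created.size()) - 1;
  }
};

// Hypothetical stand-in for StageWithSortSparseOp::stageWithSort. It fails
// when there is nothing to stage; otherwise it routes the data through a
// temporary and reports that temporary through `tmpBuf`.
// The recorded op names are placeholders, not what MLIR emits.
bool stageWithSort(ToyRewriter &rewriter, bool sourceIsOrdered,
                   std::optional<int> &tmpBuf) {
  assert(!tmpBuf && "caller must pass an empty tmpBuf");
  if (sourceIsOrdered)
    return false;
  tmpBuf = rewriter.create("tmp");
  rewriter.create("sort(tmp)");
  rewriter.create("convert(tmp -> dst)");
  return true;
}

// Mirrors the control flow of StageUnorderedSparseOps::matchAndRewrite.
bool stageUnordered(ToyRewriter &rewriter, bool sourceIsOrdered) {
  std::optional<int> tmpBuf;  // analogue of `Value tmpBuf = nullptr;`
  bool staged = stageWithSort(rewriter, sourceIsOrdered, tmpBuf);
  // Emit the cleanup only if staging succeeded and produced a temporary.
  if (staged && tmpBuf)
    rewriter.create("dealloc(tmp)");
  return staged;
}
```

Called with an unordered source, `stageUnordered` records four ops ending in the deallocation; with an ordered source it records nothing and reports failure, so no cleanup is emitted.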
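// ...
// Sets up StageSparseOperation rewriting rules.
void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
  patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
               StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
}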
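The registration call relies on a variadic `add` that constructs one instance of each listed pattern type with the same arguments, and on a class template instantiated once per op type. The following C++17 sketch reproduces only that shape; `ToyPatternSet`, `Pattern`, the `*Tag` types, and `populateToyPatterns` are hypothetical stand-ins, not MLIR's `RewritePatternSet` API.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins; MLIR's MLIRContext and RewritePatternSet are richer.
struct Context {};

struct Pattern {
  virtual ~Pattern() = default;
  virtual std::string name() const = 0;
};

class ToyPatternSet {
public:
  explicit ToyPatternSet(Context *ctx) : ctx(ctx) {
    assert(ctx && "a context is required");
  }
  Context *getContext() const { return ctx; }

  // Constructs one instance of each listed pattern type, passing every one
  // the same constructor arguments (same shape as patterns.add<A, B, C>(ctx)).
  template <typename... Ps, typename... Args>
  ToyPatternSet &add(Args &&...args) {
    (patterns.push_back(std::make_unique<Ps>(args...)), ...);
    return *this;
  }

  std::vector<std::unique_ptr<Pattern>> patterns;

private:
  Context *ctx;
};

struct GuardSparseAllocToy : Pattern {
  explicit GuardSparseAllocToy(Context *) {}
  std::string name() const override { return "GuardSparseAlloc"; }
};

// Hypothetical tags standing in for ConvertOp and ConcatenateOp.
struct ConvertOpTag {
  static constexpr const char *kName = "ConvertOp";
};
struct ConcatenateOpTag {
  static constexpr const char *kName = "ConcatenateOp";
};

// One class template instantiated per op type, mirroring
// StageUnorderedSparseOps<ConvertOp> and StageUnorderedSparseOps<ConcatenateOp>.
template <typename OpTag>
struct StageUnorderedToy : Pattern {
  explicit StageUnorderedToy(Context *) {}
  std::string name() const override {
    return std::string("StageUnorderedSparseOps<") + OpTag::kName + ">";
  }
};

// Hypothetical analogue of populateStageSparseOperationsPatterns.
void populateToyPatterns(ToyPatternSet &patterns) {
  patterns.add<GuardSparseAllocToy, StageUnorderedToy<ConvertOpTag>,
               StageUnorderedToy<ConcatenateOpTag>>(patterns.getContext());
}
```

Because the comma fold in `add` evaluates left to right, patterns are stored in the order they are listed. Writing the staging logic once as a template, rather than once per op, keeps `ConvertOp` and `ConcatenateOp` on the same code path.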