MLIR 22.0.0git
SparseTensorInterfaces.cpp
Go to the documentation of this file.
1//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13
14using namespace mlir;
15using namespace mlir::sparse_tensor;
16
17#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
18
19/// Stage the operations into a sequence of simple operations as follow:
20/// op -> unsorted_coo +
21/// unsorted_coo -> sorted_coo +
22/// sorted_coo -> dstTp.
23///
24/// return `tmpBuf` if a intermediate memory is allocated.
26 StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
27 if (!op.needsExtraSort())
28 return failure();
29
30 Location loc = op.getLoc();
31 Type finalTp = op->getOpResult(0).getType();
32 SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
33 Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
34
35 // Clones the original operation but changing the output to an unordered COO.
36 Operation *cloned = rewriter.clone(*op.getOperation());
37 rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
38 cloned->getOpResult(0).setType(srcCOOTp);
39 });
40 Value srcCOO = cloned->getOpResult(0);
41
42 // -> sort
43 Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
44 Value dstCOO = ReorderCOOOp::create(rewriter, loc, dstCOOTp, srcCOO,
45 SparseTensorSortKind::HybridQuickSort);
46
47 // -> dest.
48 if (dstCOO.getType() == finalTp) {
49 rewriter.replaceOp(op, dstCOO);
50 } else {
51 // Need an extra conversion if the target type is not COO.
52 auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
53 rewriter.setInsertionPointAfter(c);
54 // Informs the caller about the intermediate buffer we allocated. We can not
55 // create a bufferization::DeallocateTensorOp here because it would
56 // introduce cyclic dependency between the SparseTensorDialect and the
57 // BufferizationDialect. Besides, whether the buffer need to be deallocated
58 // by SparseTensorDialect or by BufferDeallocationPass is still TBD.
59 tmpBufs = dstCOO;
60 }
61
62 return success();
63}
return success()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getOpResult(unsigned idx)
Definition Operation.h:421
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
A wrapper around RankedTensorType, which has three goals:
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs)
Include the generated interface declarations.