MLIR  19.0.0git
Go to the documentation of this file.
1 //===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
12 #include "mlir/IR/PatternMatch.h"
14 using namespace mlir;
15 using namespace mlir::sparse_tensor;
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
19 /// Stage the operations into a sequence of simple operations as follows:
20 /// op -> unsorted_coo +
21 /// unsorted_coo -> sorted_coo +
22 /// sorted_coo -> dstTp.
23 ///
24 /// Returns failure, before any rewrite is performed, when `op` reports that
/// it does not need an extra sort; returns success once `op` has been
/// replaced by the staged sequence.
///
/// `tmpBufs` is an output parameter: it is set to the sorted COO value only
/// when an intermediate buffer is allocated, i.e. when the final result type
/// is not already the ordered COO type and an extra conversion is emitted.
/// It is left untouched on every other path.
///
/// NOTE(review): the first line of the signature (return type and qualified
/// function name) is missing from this listing; the cross-reference at the
/// end of the page gives it as `LogicalResult stageWithSortImpl(...)`.
26  StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
    // Nothing to stage if the operation does not require an extra sort.
27  if (!op.needsExtraSort())
28  return failure();
30  Location loc = op.getLoc();
    // The type the original operation produces as its first result; the staged
    // sequence must end with a value of exactly this type.
31  Type finalTp = op->getOpResult(0).getType();
32  SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
33  Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
35  // Clones the original operation but changes its output to an unordered COO.
    // The result type is updated through the rewriter so that the in-place
    // modification of the clone is reported to it.
36  Operation *cloned = rewriter.clone(*op.getOperation());
37  rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
38  cloned->getOpResult(0).setType(srcCOOTp);
39  });
40  Value srcCOO = cloned->getOpResult(0);
42  // -> sort: reorder the unordered COO into an ordered COO.
43  Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
44  Value dstCOO = rewriter.create<ReorderCOOOp>(
45  loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
47  // -> dest: the sorted COO is the final result only if its type already
    // matches the requested result type.
48  if (dstCOO.getType() == finalTp) {
49  rewriter.replaceOp(op, dstCOO);
50  } else {
51  // Need an extra conversion if the target type is not COO.
52  auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
    // Subsequent insertions through this rewriter go after the conversion.
53  rewriter.setInsertionPointAfter(c);
54  // Informs the caller about the intermediate buffer we allocated. We cannot
55  // create a bufferization::DeallocateTensorOp here because it would
56  // introduce a cyclic dependency between the SparseTensorDialect and the
57  // BufferizationDialect. Besides, whether the buffer needs to be deallocated
58  // by SparseTensorDialect or by BufferDeallocationPass is still TBD.
59  tmpBufs = dstCOO;
60  }
62  return success();
63 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A wrapper around RankedTensorType, which has three goals:
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs)
Include the generated interface declarations.