MLIR  19.0.0git
SubsetInsertionOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 
15 using namespace mlir;
16 using namespace mlir::tensor;
17 
18 namespace {
19 
20 struct ExtractSliceOpSubsetOpInterface
21  : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
22  tensor::ExtractSliceOp> {
24  getAccessedHyperrectangularSlice(Operation *op) const {
25  return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
26  }
27 };
28 
29 struct ExtractSliceOpSubsetExtractionOpInterface
30  : public SubsetExtractionOpInterface::ExternalModel<
31  ExtractSliceOpSubsetExtractionOpInterface, tensor::ExtractSliceOp> {
32  OpOperand &getSourceOperand(Operation *op) const {
33  return cast<tensor::ExtractSliceOp>(op).getSourceMutable();
34  }
35 };
36 
37 template <typename OpTy>
38 struct InsertSliceLikeOpSubsetOpInterface
39  : public SubsetOpInterface::ExternalModel<
40  InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
42  getAccessedHyperrectangularSlice(Operation *op) const {
43  return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
44  }
45 };
46 
47 template <typename OpTy>
48 struct InsertSliceLikeOpSubsetInsertionOpInterface
49  : public SubsetInsertionOpInterface::ExternalModel<
50  InsertSliceLikeOpSubsetInsertionOpInterface<OpTy>, OpTy> {
51  OpOperand &getSourceOperand(Operation *op) const {
52  return cast<OpTy>(op).getSourceMutable();
53  }
54 
55  OpOperand &getDestinationOperand(Operation *op) const {
56  return cast<OpTy>(op).getDestMutable();
57  }
58 
59  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
60  Location loc) const {
61  auto insertSliceOp = cast<OpTy>(op);
62  auto extractOp = builder.create<tensor::ExtractSliceOp>(
63  loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
64  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
65  insertSliceOp.getMixedStrides());
66  return extractOp.getResult();
67  }
68 
70  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
71  auto insertSliceOp = cast<OpTy>(op);
72  SmallVector<Value> neededValues;
73  // Collect all values that are needed to construct the replacement op.
74  neededValues.append(insertSliceOp.getOffsets().begin(),
75  insertSliceOp.getOffsets().end());
76  neededValues.append(insertSliceOp.getSizes().begin(),
77  insertSliceOp.getSizes().end());
78  neededValues.append(insertSliceOp.getStrides().begin(),
79  insertSliceOp.getStrides().end());
80  neededValues.push_back(insertSliceOp.getDest());
81  return neededValues;
82  }
83 };
84 
85 } // namespace
86 
88  DialectRegistry &registry) {
89  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
90  // Note: `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`
91  // require `SubsetOpInterface`.
92  ExtractSliceOp::attachInterface<ExtractSliceOpSubsetOpInterface>(*ctx);
93  ExtractSliceOp::attachInterface<ExtractSliceOpSubsetExtractionOpInterface>(
94  *ctx);
95  InsertSliceOp::attachInterface<
96  InsertSliceLikeOpSubsetOpInterface<InsertSliceOp>>(*ctx);
97  InsertSliceOp::attachInterface<
98  InsertSliceLikeOpSubsetInsertionOpInterface<InsertSliceOp>>(*ctx);
99  ParallelInsertSliceOp::attachInterface<
100  InsertSliceLikeOpSubsetOpInterface<ParallelInsertSliceOp>>(*ctx);
101  ParallelInsertSliceOp::attachInterface<
102  InsertSliceLikeOpSubsetInsertionOpInterface<ParallelInsertSliceOp>>(
103  *ctx);
104  });
105 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A hyperrectangular slice, represented as a list of offsets, sizes and strides.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.