MLIR 22.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
13
14using namespace mlir;
15
16namespace mlir {
17namespace tensor {
18namespace {
19
20struct CastOpInterface
21 : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
23 ValueBoundsConstraintSet &cstr) const {
24 auto castOp = cast<CastOp>(op);
25 assert(value == castOp.getResult() && "invalid value");
26
27 if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28 llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29 cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
30 }
31 }
32};
33
34struct CollapseShapeOpInterface
35 : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
36 CollapseShapeOp> {
37 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
38 ValueBoundsConstraintSet &cstr) const {
39 auto collapseOp = cast<CollapseShapeOp>(op);
40 assert(value == collapseOp.getResult() && "invalid value");
41
42 // Multiply the expressions for the dimensions in the reassociation group.
43 const ReassociationIndices reassocIndices =
44 collapseOp.getReassociationIndices()[dim];
45 AffineExpr productExpr =
46 cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
47 for (size_t i = 1; i < reassocIndices.size(); ++i) {
48 productExpr =
49 productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
50 }
51 cstr.bound(value)[dim] == productExpr;
52 }
53};
54
55struct DimOpInterface
56 : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
57 void populateBoundsForIndexValue(Operation *op, Value value,
58 ValueBoundsConstraintSet &cstr) const {
59 auto dimOp = cast<DimOp>(op);
60 assert(value == dimOp.getResult() && "invalid value");
61
62 cstr.bound(value) >= 0;
63 auto constIndex = dimOp.getConstantIndex();
64 if (!constIndex.has_value())
65 return;
66 cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
67 }
68};
69
70struct EmptyOpInterface
71 : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
72 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
73 ValueBoundsConstraintSet &cstr) const {
74 auto emptyOp = cast<EmptyOp>(op);
75 assert(value == emptyOp.getResult() && "invalid value");
76
77 cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
78 }
79};
80
81struct ExpandShapeOpInterface
82 : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
83 ExpandShapeOp> {
84 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
85 ValueBoundsConstraintSet &cstr) const {
86 auto expandOp = cast<ExpandShapeOp>(op);
87 assert(value == expandOp.getResult() && "invalid value");
88 cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
89 }
90};
91
92struct ExtractSliceOpInterface
93 : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
94 ExtractSliceOp> {
95 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
96 ValueBoundsConstraintSet &cstr) const {
97 auto extractSliceOp = cast<ExtractSliceOp>(op);
98 assert(value == extractSliceOp.getResult() && "invalid value");
99
100 llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
101 int64_t ctr = -1;
102 for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
103 // Skip over rank-reduced dimensions.
104 if (!dropped.test(i))
105 ++ctr;
106 if (ctr == dim) {
107 cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
108 return;
109 }
110 }
111 llvm_unreachable("could not find non-rank-reduced dim");
112 }
113};
114
115struct PadOpInterface
116 : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
117 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
118 ValueBoundsConstraintSet &cstr) const {
119 auto padOp = cast<PadOp>(op);
120 assert(value == padOp.getResult() && "invalid value");
121
122 AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim);
123 AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]);
124 AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]);
125 cstr.bound(value)[dim] == srcSize + lowPad + highPad;
126 }
127};
128
129struct RankOpInterface
130 : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
131 void populateBoundsForIndexValue(Operation *op, Value value,
132 ValueBoundsConstraintSet &cstr) const {
133 auto rankOp = cast<RankOp>(op);
134 assert(value == rankOp.getResult() && "invalid value");
135
136 auto tensorType =
137 llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
138 if (!tensorType)
139 return;
140 cstr.bound(value) == tensorType.getRank();
141 }
142};
143
144} // namespace
145} // namespace tensor
146} // namespace mlir
147
149 DialectRegistry &registry) {
150 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
151 tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
152 tensor::CollapseShapeOp::attachInterface<tensor::CollapseShapeOpInterface>(
153 *ctx);
154 tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
155 tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
156 tensor::ExpandShapeOp::attachInterface<tensor::ExpandShapeOpInterface>(
157 *ctx);
158 tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
159 *ctx);
160 tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
161 tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
162 // Note: ValueBoundsOpInterface implementation is not required for ops that
163 // implement `DestinationStyleOpInterface` (for querying shaped OpResults).
164 });
165}
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
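Register the tensor-dialect external models of ValueBoundsOpInterface (for cast, collapse_shape, dim, empty, expand_shape, extract_slice, pad and rank) with the given registry.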
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27