MLIR 22.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
13
14using namespace mlir;
15
16namespace mlir {
17namespace memref {
18namespace {
19
20template <typename OpTy>
21struct AllocOpInterface
22 : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
23 OpTy> {
24 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
25 ValueBoundsConstraintSet &cstr) const {
26 auto allocOp = cast<OpTy>(op);
27 assert(value == allocOp.getResult() && "invalid value");
28
29 cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
30 }
31};
32
33struct CastOpInterface
34 : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
35 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
36 ValueBoundsConstraintSet &cstr) const {
37 auto castOp = cast<CastOp>(op);
38 assert(value == castOp.getResult() && "invalid value");
39
40 if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
41 llvm::isa<MemRefType>(castOp.getSource().getType())) {
42 cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
43 }
44 }
45};
46
47struct DimOpInterface
48 : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
49 void populateBoundsForIndexValue(Operation *op, Value value,
50 ValueBoundsConstraintSet &cstr) const {
51 auto dimOp = cast<DimOp>(op);
52 assert(value == dimOp.getResult() && "invalid value");
53
54 cstr.bound(value) >= 0;
55 auto constIndex = dimOp.getConstantIndex();
56 if (!constIndex.has_value())
57 return;
58 cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
59 }
60};
61
62struct ExpandShapeOpInterface
63 : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
64 memref::ExpandShapeOp> {
65 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
66 ValueBoundsConstraintSet &cstr) const {
67 auto expandOp = cast<memref::ExpandShapeOp>(op);
68 assert(value == expandOp.getResult() && "invalid value");
69 cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
70 }
71};
72
73struct GetGlobalOpInterface
74 : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
75 GetGlobalOp> {
76 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
77 ValueBoundsConstraintSet &cstr) const {
78 auto getGlobalOp = cast<GetGlobalOp>(op);
79 assert(value == getGlobalOp.getResult() && "invalid value");
80
81 auto type = getGlobalOp.getType();
82 assert(!type.isDynamicDim(dim) && "expected static dim");
83 cstr.bound(value)[dim] == type.getDimSize(dim);
84 }
85};
86
87struct RankOpInterface
88 : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
89 void populateBoundsForIndexValue(Operation *op, Value value,
90 ValueBoundsConstraintSet &cstr) const {
91 auto rankOp = cast<RankOp>(op);
92 assert(value == rankOp.getResult() && "invalid value");
93
94 auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
95 if (!memrefType)
96 return;
97 cstr.bound(value) == memrefType.getRank();
98 }
99};
100
101struct CollapseShapeOpInterface
102 : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
103 memref::CollapseShapeOp> {
104 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
105 ValueBoundsConstraintSet &cstr) const {
106 auto collapseOp = cast<memref::CollapseShapeOp>(op);
107 assert(value == collapseOp.getResult() && "invalid value");
108
109 // Multiply the expressions for the dimensions in the reassociation group.
110 const ReassociationIndices reassocIndices =
111 collapseOp.getReassociationIndices()[dim];
112 AffineExpr productExpr =
113 cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
114 for (size_t i = 1; i < reassocIndices.size(); ++i) {
115 productExpr =
116 productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
117 }
118 cstr.bound(value)[dim] == productExpr;
119 }
120};
121
122struct SubViewOpInterface
123 : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
124 SubViewOp> {
125 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
126 ValueBoundsConstraintSet &cstr) const {
127 auto subViewOp = cast<SubViewOp>(op);
128 assert(value == subViewOp.getResult() && "invalid value");
129
130 llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
131 int64_t ctr = -1;
132 for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
133 // Skip over rank-reduced dimensions.
134 if (!dropped.test(i))
135 ++ctr;
136 if (ctr == dim) {
137 cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
138 return;
139 }
140 }
141 llvm_unreachable("could not find non-rank-reduced dim");
142 }
143};
144
145} // namespace
146} // namespace memref
147} // namespace mlir
148
150 DialectRegistry &registry) {
151 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
152 memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
153 *ctx);
154 memref::AllocaOp::attachInterface<
155 memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
156 memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
157 memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
158 memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
159 *ctx);
160 memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
161 *ctx);
162 memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
163 memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
164 memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
165 });
166}
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27