MLIR  22.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace memref {
18 namespace {
19 
20 template <typename OpTy>
21 struct AllocOpInterface
22  : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
23  OpTy> {
24  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
25  ValueBoundsConstraintSet &cstr) const {
26  auto allocOp = cast<OpTy>(op);
27  assert(value == allocOp.getResult() && "invalid value");
28 
29  cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
30  }
31 };
32 
33 struct CastOpInterface
34  : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
35  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
36  ValueBoundsConstraintSet &cstr) const {
37  auto castOp = cast<CastOp>(op);
38  assert(value == castOp.getResult() && "invalid value");
39 
40  if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
41  llvm::isa<MemRefType>(castOp.getSource().getType())) {
42  cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
43  }
44  }
45 };
46 
47 struct DimOpInterface
48  : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
49  void populateBoundsForIndexValue(Operation *op, Value value,
50  ValueBoundsConstraintSet &cstr) const {
51  auto dimOp = cast<DimOp>(op);
52  assert(value == dimOp.getResult() && "invalid value");
53 
54  cstr.bound(value) >= 0;
55  auto constIndex = dimOp.getConstantIndex();
56  if (!constIndex.has_value())
57  return;
58  cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
59  }
60 };
61 
62 struct ExpandShapeOpInterface
63  : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
64  memref::ExpandShapeOp> {
65  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
66  ValueBoundsConstraintSet &cstr) const {
67  auto expandOp = cast<memref::ExpandShapeOp>(op);
68  assert(value == expandOp.getResult() && "invalid value");
69  cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
70  }
71 };
72 
73 struct GetGlobalOpInterface
74  : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
75  GetGlobalOp> {
76  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
77  ValueBoundsConstraintSet &cstr) const {
78  auto getGlobalOp = cast<GetGlobalOp>(op);
79  assert(value == getGlobalOp.getResult() && "invalid value");
80 
81  auto type = getGlobalOp.getType();
82  assert(!type.isDynamicDim(dim) && "expected static dim");
83  cstr.bound(value)[dim] == type.getDimSize(dim);
84  }
85 };
86 
87 struct RankOpInterface
88  : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
89  void populateBoundsForIndexValue(Operation *op, Value value,
90  ValueBoundsConstraintSet &cstr) const {
91  auto rankOp = cast<RankOp>(op);
92  assert(value == rankOp.getResult() && "invalid value");
93 
94  auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
95  if (!memrefType)
96  return;
97  cstr.bound(value) == memrefType.getRank();
98  }
99 };
100 
101 struct CollapseShapeOpInterface
102  : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
103  memref::CollapseShapeOp> {
104  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
105  ValueBoundsConstraintSet &cstr) const {
106  auto collapseOp = cast<memref::CollapseShapeOp>(op);
107  assert(value == collapseOp.getResult() && "invalid value");
108 
109  // Multiply the expressions for the dimensions in the reassociation group.
110  const ReassociationIndices reassocIndices =
111  collapseOp.getReassociationIndices()[dim];
112  AffineExpr productExpr =
113  cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
114  for (size_t i = 1; i < reassocIndices.size(); ++i) {
115  productExpr =
116  productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
117  }
118  cstr.bound(value)[dim] == productExpr;
119  }
120 };
121 
122 struct SubViewOpInterface
123  : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
124  SubViewOp> {
125  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
126  ValueBoundsConstraintSet &cstr) const {
127  auto subViewOp = cast<SubViewOp>(op);
128  assert(value == subViewOp.getResult() && "invalid value");
129 
130  llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
131  int64_t ctr = -1;
132  for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
133  // Skip over rank-reduced dimensions.
134  if (!dropped.test(i))
135  ++ctr;
136  if (ctr == dim) {
137  cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
138  return;
139  }
140  }
141  llvm_unreachable("could not find non-rank-reduced dim");
142  }
143 };
144 
145 } // namespace
146 } // namespace memref
147 } // namespace mlir
148 
150  DialectRegistry &registry) {
151  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
152  memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
153  *ctx);
154  memref::AllocaOp::attachInterface<
155  memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
156  memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
157  memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
158  memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
159  *ctx);
160  memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
161  *ctx);
162  memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
163  memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
164  memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
165  });
166 }
Base type for affine expression.
Definition: AffineExpr.h:68
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.