MLIR  21.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace memref {
18 namespace {
19 
20 template <typename OpTy>
21 struct AllocOpInterface
22  : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
23  OpTy> {
24  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
25  ValueBoundsConstraintSet &cstr) const {
26  auto allocOp = cast<OpTy>(op);
27  assert(value == allocOp.getResult() && "invalid value");
28 
29  cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
30  }
31 };
32 
33 struct CastOpInterface
34  : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
35  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
36  ValueBoundsConstraintSet &cstr) const {
37  auto castOp = cast<CastOp>(op);
38  assert(value == castOp.getResult() && "invalid value");
39 
40  if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
41  llvm::isa<MemRefType>(castOp.getSource().getType())) {
42  cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
43  }
44  }
45 };
46 
47 struct DimOpInterface
48  : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
49  void populateBoundsForIndexValue(Operation *op, Value value,
50  ValueBoundsConstraintSet &cstr) const {
51  auto dimOp = cast<DimOp>(op);
52  assert(value == dimOp.getResult() && "invalid value");
53 
54  cstr.bound(value) >= 0;
55  auto constIndex = dimOp.getConstantIndex();
56  if (!constIndex.has_value())
57  return;
58  cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
59  }
60 };
61 
62 struct GetGlobalOpInterface
63  : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
64  GetGlobalOp> {
65  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
66  ValueBoundsConstraintSet &cstr) const {
67  auto getGlobalOp = cast<GetGlobalOp>(op);
68  assert(value == getGlobalOp.getResult() && "invalid value");
69 
70  auto type = getGlobalOp.getType();
71  assert(!type.isDynamicDim(dim) && "expected static dim");
72  cstr.bound(value)[dim] == type.getDimSize(dim);
73  }
74 };
75 
76 struct RankOpInterface
77  : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
78  void populateBoundsForIndexValue(Operation *op, Value value,
79  ValueBoundsConstraintSet &cstr) const {
80  auto rankOp = cast<RankOp>(op);
81  assert(value == rankOp.getResult() && "invalid value");
82 
83  auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
84  if (!memrefType)
85  return;
86  cstr.bound(value) == memrefType.getRank();
87  }
88 };
89 
90 struct SubViewOpInterface
91  : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
92  SubViewOp> {
93  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
94  ValueBoundsConstraintSet &cstr) const {
95  auto subViewOp = cast<SubViewOp>(op);
96  assert(value == subViewOp.getResult() && "invalid value");
97 
98  llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
99  int64_t ctr = -1;
100  for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
101  // Skip over rank-reduced dimensions.
102  if (!dropped.test(i))
103  ++ctr;
104  if (ctr == dim) {
105  cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
106  return;
107  }
108  }
109  llvm_unreachable("could not find non-rank-reduced dim");
110  }
111 };
112 
113 } // namespace
114 } // namespace memref
115 } // namespace mlir
116 
118  DialectRegistry &registry) {
119  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
120  memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
121  *ctx);
122  memref::AllocaOp::attachInterface<
123  memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
124  memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
125  memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
126  memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
127  memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
128  memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
129  });
130 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.