MLIR  19.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace memref {
18 namespace {
19 
20 template <typename OpTy>
21 struct AllocOpInterface
22  : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
23  OpTy> {
24  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
25  ValueBoundsConstraintSet &cstr) const {
26  auto allocOp = cast<OpTy>(op);
27  assert(value == allocOp.getResult() && "invalid value");
28 
29  cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
30  }
31 };
32 
33 struct CastOpInterface
34  : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
35  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
36  ValueBoundsConstraintSet &cstr) const {
37  auto castOp = cast<CastOp>(op);
38  assert(value == castOp.getResult() && "invalid value");
39 
40  if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
41  llvm::isa<MemRefType>(castOp.getSource().getType())) {
42  cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
43  }
44  }
45 };
46 
47 struct DimOpInterface
48  : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
49  void populateBoundsForIndexValue(Operation *op, Value value,
50  ValueBoundsConstraintSet &cstr) const {
51  auto dimOp = cast<DimOp>(op);
52  assert(value == dimOp.getResult() && "invalid value");
53 
54  auto constIndex = dimOp.getConstantIndex();
55  if (!constIndex.has_value())
56  return;
57  cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
58  }
59 };
60 
61 struct GetGlobalOpInterface
62  : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
63  GetGlobalOp> {
64  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
65  ValueBoundsConstraintSet &cstr) const {
66  auto getGlobalOp = cast<GetGlobalOp>(op);
67  assert(value == getGlobalOp.getResult() && "invalid value");
68 
69  auto type = getGlobalOp.getType();
70  assert(!type.isDynamicDim(dim) && "expected static dim");
71  cstr.bound(value)[dim] == type.getDimSize(dim);
72  }
73 };
74 
75 struct RankOpInterface
76  : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
77  void populateBoundsForIndexValue(Operation *op, Value value,
78  ValueBoundsConstraintSet &cstr) const {
79  auto rankOp = cast<RankOp>(op);
80  assert(value == rankOp.getResult() && "invalid value");
81 
82  auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
83  if (!memrefType)
84  return;
85  cstr.bound(value) == memrefType.getRank();
86  }
87 };
88 
89 struct SubViewOpInterface
90  : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
91  SubViewOp> {
92  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
93  ValueBoundsConstraintSet &cstr) const {
94  auto subViewOp = cast<SubViewOp>(op);
95  assert(value == subViewOp.getResult() && "invalid value");
96 
97  llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
98  int64_t ctr = -1;
99  for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
100  // Skip over rank-reduced dimensions.
101  if (!dropped.test(i))
102  ++ctr;
103  if (ctr == dim) {
104  cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
105  return;
106  }
107  }
108  llvm_unreachable("could not find non-rank-reduced dim");
109  }
110 };
111 
112 } // namespace
113 } // namespace memref
114 } // namespace mlir
115 
117  DialectRegistry &registry) {
  // Attaches the ValueBoundsOpInterface external models defined in this file
  // to their memref ops. The attachment is wrapped in a registry extension
  // whose callback receives the MLIRContext and the memref dialect. The
  // `dialect` argument is not used here. NOTE(review): presumably the callback
  // runs once the memref dialect is loaded into a context; confirm against
  // DialectRegistry::addExtension.
118  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
  // memref.alloc and memref.alloca share the templated AllocOpInterface model.
119  memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
120  *ctx);
121  memref::AllocaOp::attachInterface<
122  memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
  // Each of the remaining ops gets its own dedicated model.
123  memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
124  memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
125  memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
126  memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
127  memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
128  });
129 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.