MLIR  21.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace tensor {
18 namespace {
19 
20 struct CastOpInterface
21  : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
23  ValueBoundsConstraintSet &cstr) const {
24  auto castOp = cast<CastOp>(op);
25  assert(value == castOp.getResult() && "invalid value");
26 
27  if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28  llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29  cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
30  }
31  }
32 };
33 
34 struct DimOpInterface
35  : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
36  void populateBoundsForIndexValue(Operation *op, Value value,
37  ValueBoundsConstraintSet &cstr) const {
38  auto dimOp = cast<DimOp>(op);
39  assert(value == dimOp.getResult() && "invalid value");
40 
41  cstr.bound(value) >= 0;
42  auto constIndex = dimOp.getConstantIndex();
43  if (!constIndex.has_value())
44  return;
45  cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
46  }
47 };
48 
49 struct EmptyOpInterface
50  : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
51  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
52  ValueBoundsConstraintSet &cstr) const {
53  auto emptyOp = cast<EmptyOp>(op);
54  assert(value == emptyOp.getResult() && "invalid value");
55 
56  cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
57  }
58 };
59 
60 struct ExtractSliceOpInterface
61  : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
62  ExtractSliceOp> {
63  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
64  ValueBoundsConstraintSet &cstr) const {
65  auto extractSliceOp = cast<ExtractSliceOp>(op);
66  assert(value == extractSliceOp.getResult() && "invalid value");
67 
68  llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
69  int64_t ctr = -1;
70  for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
71  // Skip over rank-reduced dimensions.
72  if (!dropped.test(i))
73  ++ctr;
74  if (ctr == dim) {
75  cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
76  return;
77  }
78  }
79  llvm_unreachable("could not find non-rank-reduced dim");
80  }
81 };
82 
83 struct PadOpInterface
84  : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
85  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
86  ValueBoundsConstraintSet &cstr) const {
87  auto padOp = cast<PadOp>(op);
88  assert(value == padOp.getResult() && "invalid value");
89 
90  AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim);
91  AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]);
92  AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]);
93  cstr.bound(value)[dim] == srcSize + lowPad + highPad;
94  }
95 };
96 
97 struct RankOpInterface
98  : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
99  void populateBoundsForIndexValue(Operation *op, Value value,
100  ValueBoundsConstraintSet &cstr) const {
101  auto rankOp = cast<RankOp>(op);
102  assert(value == rankOp.getResult() && "invalid value");
103 
104  auto tensorType =
105  llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
106  if (!tensorType)
107  return;
108  cstr.bound(value) == tensorType.getRank();
109  }
110 };
111 
112 } // namespace
113 } // namespace tensor
114 } // namespace mlir
115 
117  DialectRegistry &registry) {
118  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
119  tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
120  tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
121  tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
122  tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
123  *ctx);
124  tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
125  tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
126  // Note: ValueBoundsOpInterface implementation is not required for ops that
127  // implement `DestinationStyleOpInterface` (for querying shaped OpResults).
128  });
129 }
Base type for affine expression.
Definition: AffineExpr.h:68
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.