MLIR  20.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace tensor {
18 namespace {
19 
20 struct CastOpInterface
21  : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
22  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
23  ValueBoundsConstraintSet &cstr) const {
24  auto castOp = cast<CastOp>(op);
25  assert(value == castOp.getResult() && "invalid value");
26 
27  if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
28  llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
29  cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
30  }
31  }
32 };
33 
34 struct DimOpInterface
35  : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
36  void populateBoundsForIndexValue(Operation *op, Value value,
37  ValueBoundsConstraintSet &cstr) const {
38  auto dimOp = cast<DimOp>(op);
39  assert(value == dimOp.getResult() && "invalid value");
40 
41  auto constIndex = dimOp.getConstantIndex();
42  if (!constIndex.has_value())
43  return;
44  cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
45  }
46 };
47 
48 struct EmptyOpInterface
49  : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
50  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
51  ValueBoundsConstraintSet &cstr) const {
52  auto emptyOp = cast<EmptyOp>(op);
53  assert(value == emptyOp.getResult() && "invalid value");
54 
55  cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
56  }
57 };
58 
59 struct ExtractSliceOpInterface
60  : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
61  ExtractSliceOp> {
62  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
63  ValueBoundsConstraintSet &cstr) const {
64  auto extractSliceOp = cast<ExtractSliceOp>(op);
65  assert(value == extractSliceOp.getResult() && "invalid value");
66 
67  llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
68  int64_t ctr = -1;
69  for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
70  // Skip over rank-reduced dimensions.
71  if (!dropped.test(i))
72  ++ctr;
73  if (ctr == dim) {
74  cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
75  return;
76  }
77  }
78  llvm_unreachable("could not find non-rank-reduced dim");
79  }
80 };
81 
82 struct PadOpInterface
83  : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
84  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
85  ValueBoundsConstraintSet &cstr) const {
86  auto padOp = cast<PadOp>(op);
87  assert(value == padOp.getResult() && "invalid value");
88 
89  AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim);
90  AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]);
91  AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]);
92  cstr.bound(value)[dim] == srcSize + lowPad + highPad;
93  }
94 };
95 
96 struct RankOpInterface
97  : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
98  void populateBoundsForIndexValue(Operation *op, Value value,
99  ValueBoundsConstraintSet &cstr) const {
100  auto rankOp = cast<RankOp>(op);
101  assert(value == rankOp.getResult() && "invalid value");
102 
103  auto tensorType =
104  llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
105  if (!tensorType)
106  return;
107  cstr.bound(value) == tensorType.getRank();
108  }
109 };
110 
111 } // namespace
112 } // namespace tensor
113 } // namespace mlir
114 
116  DialectRegistry &registry) {
117  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
118  tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
119  tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
120  tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
121  tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
122  *ctx);
123  tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
124  tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
125  // Note: ValueBoundsOpInterface implementation is not required for ops that
126  // implement `DestinationStyleOpInterface` (for querying shaped OpResults).
127  });
128 }
Base type for affine expression.
Definition: AffineExpr.h:68
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.