MLIR  19.0.0git
ScalableValueBoundsConstraintSet.cpp
Go to the documentation of this file.
1 //===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 namespace mlir::vector {
12 
13 FailureOr<ConstantOrScalableBound::BoundSize>
15  if (map.isSingleConstant())
16  return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
17  if (map.getNumResults() != 1 || map.getNumInputs() != 1)
18  return failure();
19  auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
20  if (!binop || binop.getKind() != AffineExprKind::Mul)
21  return failure();
22  auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
23  if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
24  constant = cst.getValue();
25  return true;
26  }
27  return false;
28  };
29  // Match `s0 * cst` or `cst * s0`:
30  int64_t cst = 0;
31  auto lhs = binop.getLHS();
32  auto rhs = binop.getRHS();
33  if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
34  (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
35  return BoundSize{cst, /*scalable=*/true};
36  }
37  return failure();
38 }
39 
41 
44  Value value, std::optional<int64_t> dim, unsigned vscaleMin,
45  unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
46  StopConditionFn stopCondition) {
47  using namespace presburger;
48  assert(vscaleMin <= vscaleMax);
49 
50  // No stop condition specified: Keep adding constraints until the worklist
51  // is empty.
52  auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
54  return false;
55  };
56 
58  value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
59  vscaleMin, vscaleMax);
60  int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
61  scalableCstr.processWorklist();
62 
63  // Check the resulting constraints set is valid.
64  if (scalableCstr.cstr.isEmpty()) {
65  return failure();
66  }
67 
68  // Project out all columns apart from vscale and the starting point
69  // (value/dim). This should result in constraints in terms of vscale only.
70  auto projectOutFn = [&](ValueDim p) {
71  bool isStartingPoint =
72  p.first == value &&
73  p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
74  return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
75  };
76  scalableCstr.projectOut(projectOutFn);
77  // Also project out local variables (these are not tracked by the
78  // ValueBoundsConstraintSet).
79  for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) {
80  scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars());
81  }
82 
83  assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
84  scalableCstr.positionToValueDim.size() &&
85  "inconsistent mapping state");
86 
87  // Check that the only columns left are vscale and the starting point.
88  for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
89  if (i == pos)
90  continue;
91  if (scalableCstr.positionToValueDim[i] !=
92  ValueDim(scalableCstr.getVscaleValue(),
94  return failure();
95  }
96  }
97 
98  SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
99  scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
100  &upperBound, closedUB);
101 
102  auto invalidBound = [](auto &bound) {
103  return !bound[0] || bound[0].getNumResults() != 1;
104  };
105 
106  AffineMap bound = [&] {
107  if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
108  lowerBound[0] == lowerBound[0]) {
109  return lowerBound[0];
110  } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
111  return lowerBound[0];
112  } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
113  return upperBound[0];
114  }
115  return AffineMap{};
116  }();
117 
118  if (!bound)
119  return failure();
120 
121  return ConstantOrScalableBound{bound};
122 }
123 
124 } // namespace mlir::vector
function_ref< bool(Region *, ArrayRef< bool > visited)> StopConditionFn
Stop condition for traverseRegionGraph.
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
int64_t getSingleConstantResult() const
Returns the constant result of this map.
Definition: AffineMap.cpp:369
bool isSingleConstant() const
Returns true if this affine map is a single result constant function.
Definition: AffineMap.cpp:361
unsigned getNumResults() const
Definition: AffineMap.cpp:390
unsigned getNumInputs() const
Definition: AffineMap.cpp:391
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:399
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A helper class to be used with ValueBoundsOpInterface.
static constexpr int64_t kIndexValue
Dimension identifier to indicate a value is index-typed.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
BoundType
The type of bound: equal, lower bound or upper bound.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
@ Mul
RHS of mul is always a constant or a symbolic expression.
A thin wrapper over an AffineMap which can represent a constant bound, or a scalable bound (in terms ...
FailureOr< BoundSize > getSize() const
Get the (possibly) scalable size of the bound, returns failure if the bound cannot be represented as ...
A version of ValueBoundsConstraintSet that can solve for scalable bounds.
static FailureOr< ConstantOrScalableBound > computeScalableBound(Value value, std::optional< int64_t > dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB=true, StopConditionFn stopCondition=nullptr)
Computes a (possibly) scalable bound for a given value.
Value getVscaleValue() const
Get the value of vscale. Returns nullptr if vscale has not been encountered.