MLIR  20.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace scf {
18 namespace {
19 
20 struct ForOpInterface
21  : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
22 
23  /// Populate bounds of values/dimensions for iter_args/OpResults. If the
24  /// value/dimension size does not change in an iteration, we can deduce that
25  /// it the same as the initial value/dimension.
26  ///
27  /// Example 1:
28  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29  /// ...
30  /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31  /// scf.yield %1 : tensor<?xf32>
32  /// }
33  /// --> bound(%0)[0] == bound(%t)[0]
34  /// --> bound(%arg0)[0] == bound(%t)[0]
35  ///
36  /// Example 2:
37  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38  /// %sz = tensor.dim %arg0 : tensor<?xf32>
39  /// %incr = arith.addi %sz, %c1 : index
40  /// %1 = tensor.empty(%incr) : tensor<?xf32>
41  /// scf.yield %1 : tensor<?xf32>
42  /// }
43  /// --> The yielded tensor dimension size changes with each iteration. Such
44  /// loops are not supported and no constraints are added.
45  static void populateIterArgBounds(scf::ForOp forOp, Value value,
46  std::optional<int64_t> dim,
48  // `value` is an iter_arg or an OpResult.
49  int64_t iterArgIdx;
50  if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
51  iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
52  } else {
53  iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
54  }
55 
56  Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
57  .getOperand(iterArgIdx);
58  Value iterArg = forOp.getRegionIterArg(iterArgIdx);
59  Value initArg = forOp.getInitArgs()[iterArgIdx];
60 
61  // An EQ constraint can be added if the yielded value (dimension size)
62  // equals the corresponding block argument (dimension size).
63  if (cstr.populateAndCompare(
64  /*lhs=*/{yieldedValue, dim},
65  ValueBoundsConstraintSet::ComparisonOperator::EQ,
66  /*rhs=*/{iterArg, dim})) {
67  if (dim.has_value()) {
68  cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
69  } else {
70  cstr.bound(value) == cstr.getExpr(initArg);
71  }
72  }
73  }
74 
75  void populateBoundsForIndexValue(Operation *op, Value value,
76  ValueBoundsConstraintSet &cstr) const {
77  auto forOp = cast<ForOp>(op);
78 
79  if (value == forOp.getInductionVar()) {
80  // TODO: Take into account step size.
81  cstr.bound(value) >= forOp.getLowerBound();
82  cstr.bound(value) < forOp.getUpperBound();
83  return;
84  }
85 
86  // Handle iter_args and OpResults.
87  populateIterArgBounds(forOp, value, std::nullopt, cstr);
88  }
89 
90  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
91  ValueBoundsConstraintSet &cstr) const {
92  auto forOp = cast<ForOp>(op);
93  // Handle iter_args and OpResults.
94  populateIterArgBounds(forOp, value, dim, cstr);
95  }
96 };
97 
98 struct IfOpInterface
99  : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
100 
101  static void populateBounds(scf::IfOp ifOp, Value value,
102  std::optional<int64_t> dim,
103  ValueBoundsConstraintSet &cstr) {
104  unsigned int resultNum = cast<OpResult>(value).getResultNumber();
105  Value thenValue = ifOp.thenYield().getResults()[resultNum];
106  Value elseValue = ifOp.elseYield().getResults()[resultNum];
107 
108  auto boundsBuilder = cstr.bound(value);
109  if (dim)
110  boundsBuilder[*dim];
111 
112  // Compare yielded values.
113  // If thenValue <= elseValue:
114  // * result <= elseValue
115  // * result >= thenValue
116  if (cstr.populateAndCompare(
117  /*lhs=*/{thenValue, dim},
118  ValueBoundsConstraintSet::ComparisonOperator::LE,
119  /*rhs=*/{elseValue, dim})) {
120  if (dim) {
121  cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
122  cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
123  } else {
124  cstr.bound(value) >= thenValue;
125  cstr.bound(value) <= elseValue;
126  }
127  }
128  // If elseValue <= thenValue:
129  // * result <= thenValue
130  // * result >= elseValue
131  if (cstr.populateAndCompare(
132  /*lhs=*/{elseValue, dim},
133  ValueBoundsConstraintSet::ComparisonOperator::LE,
134  /*rhs=*/{thenValue, dim})) {
135  if (dim) {
136  cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
137  cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
138  } else {
139  cstr.bound(value) >= elseValue;
140  cstr.bound(value) <= thenValue;
141  }
142  }
143  }
144 
145  void populateBoundsForIndexValue(Operation *op, Value value,
146  ValueBoundsConstraintSet &cstr) const {
147  populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
148  }
149 
150  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
151  ValueBoundsConstraintSet &cstr) const {
152  populateBounds(cast<IfOp>(op), value, dim, cstr);
153  }
154 };
155 
156 } // namespace
157 } // namespace scf
158 } // namespace mlir
159 
161  DialectRegistry &registry) {
162  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
163  scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
164  scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
165  });
166 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Populate constraints for lhs/rhs (until the stop condition is met).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.