MLIR  20.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
15 
16 namespace mlir {
17 namespace scf {
18 namespace {
19 
20 struct ForOpInterface
21  : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
22 
23  /// Populate bounds of values/dimensions for iter_args/OpResults. If the
24  /// value/dimension size does not change in an iteration, we can deduce that
25  /// it the same as the initial value/dimension.
26  ///
27  /// Example 1:
28  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29  /// ...
30  /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31  /// scf.yield %1 : tensor<?xf32>
32  /// }
33  /// --> bound(%0)[0] == bound(%t)[0]
34  /// --> bound(%arg0)[0] == bound(%t)[0]
35  ///
36  /// Example 2:
37  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38  /// %sz = tensor.dim %arg0 : tensor<?xf32>
39  /// %incr = arith.addi %sz, %c1 : index
40  /// %1 = tensor.empty(%incr) : tensor<?xf32>
41  /// scf.yield %1 : tensor<?xf32>
42  /// }
43  /// --> The yielded tensor dimension size changes with each iteration. Such
44  /// loops are not supported and no constraints are added.
45  static void populateIterArgBounds(scf::ForOp forOp, Value value,
46  std::optional<int64_t> dim,
48  // `value` is an iter_arg or an OpResult.
49  int64_t iterArgIdx;
50  if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
51  iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
52  } else {
53  iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
54  }
55 
56  Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
57  .getOperand(iterArgIdx);
58  Value iterArg = forOp.getRegionIterArg(iterArgIdx);
59  Value initArg = forOp.getInitArgs()[iterArgIdx];
60 
61  // An EQ constraint can be added if the yielded value (dimension size)
62  // equals the corresponding block argument (dimension size).
63  if (cstr.populateAndCompare(
64  /*lhs=*/{yieldedValue, dim},
65  ValueBoundsConstraintSet::ComparisonOperator::EQ,
66  /*rhs=*/{iterArg, dim})) {
67  if (dim.has_value()) {
68  cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
69  } else {
70  cstr.bound(value) == cstr.getExpr(initArg);
71  }
72  }
73  }
74 
75  void populateBoundsForIndexValue(Operation *op, Value value,
76  ValueBoundsConstraintSet &cstr) const {
77  auto forOp = cast<ForOp>(op);
78 
79  if (value == forOp.getInductionVar()) {
80  // TODO: Take into account step size.
81  cstr.bound(value) >= forOp.getLowerBound();
82  cstr.bound(value) < forOp.getUpperBound();
83  return;
84  }
85 
86  // Handle iter_args and OpResults.
87  populateIterArgBounds(forOp, value, std::nullopt, cstr);
88  }
89 
90  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
91  ValueBoundsConstraintSet &cstr) const {
92  auto forOp = cast<ForOp>(op);
93  // Handle iter_args and OpResults.
94  populateIterArgBounds(forOp, value, dim, cstr);
95  }
96 };
97 
98 struct ForallOpInterface
99  : public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
100  ForallOp> {
101 
102  void populateBoundsForIndexValue(Operation *op, Value value,
103  ValueBoundsConstraintSet &cstr) const {
104  auto forallOp = cast<ForallOp>(op);
105 
106  // Index values should be induction variables, since the semantics of
107  // tensor::ParallelInsertSliceOp requires forall outputs to be ranked
108  // tensors.
109  auto blockArg = cast<BlockArgument>(value);
110  assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
111  "expected index value to be an induction var");
112  int64_t idx = blockArg.getArgNumber();
113  // TODO: Take into account step size.
114  AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]);
115  AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]);
116  cstr.bound(value) >= lb;
117  cstr.bound(value) < ub;
118  }
119 
120  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
121  ValueBoundsConstraintSet &cstr) const {
122  auto forallOp = cast<ForallOp>(op);
123 
124  // `value` is an iter_arg or an OpResult.
125  int64_t iterArgIdx;
126  if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
127  iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
128  } else {
129  iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
130  }
131 
132  // The forall results and output arguments have the same sizes as the output
133  // operands.
134  Value outputOperand = forallOp.getOutputs()[iterArgIdx];
135  cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim);
136  }
137 };
138 
139 struct IfOpInterface
140  : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
141 
142  static void populateBounds(scf::IfOp ifOp, Value value,
143  std::optional<int64_t> dim,
144  ValueBoundsConstraintSet &cstr) {
145  unsigned int resultNum = cast<OpResult>(value).getResultNumber();
146  Value thenValue = ifOp.thenYield().getResults()[resultNum];
147  Value elseValue = ifOp.elseYield().getResults()[resultNum];
148 
149  auto boundsBuilder = cstr.bound(value);
150  if (dim)
151  boundsBuilder[*dim];
152 
153  // Compare yielded values.
154  // If thenValue <= elseValue:
155  // * result <= elseValue
156  // * result >= thenValue
157  if (cstr.populateAndCompare(
158  /*lhs=*/{thenValue, dim},
159  ValueBoundsConstraintSet::ComparisonOperator::LE,
160  /*rhs=*/{elseValue, dim})) {
161  if (dim) {
162  cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
163  cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
164  } else {
165  cstr.bound(value) >= thenValue;
166  cstr.bound(value) <= elseValue;
167  }
168  }
169  // If elseValue <= thenValue:
170  // * result <= thenValue
171  // * result >= elseValue
172  if (cstr.populateAndCompare(
173  /*lhs=*/{elseValue, dim},
174  ValueBoundsConstraintSet::ComparisonOperator::LE,
175  /*rhs=*/{thenValue, dim})) {
176  if (dim) {
177  cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
178  cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
179  } else {
180  cstr.bound(value) >= elseValue;
181  cstr.bound(value) <= thenValue;
182  }
183  }
184  }
185 
186  void populateBoundsForIndexValue(Operation *op, Value value,
187  ValueBoundsConstraintSet &cstr) const {
188  populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
189  }
190 
191  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
192  ValueBoundsConstraintSet &cstr) const {
193  populateBounds(cast<IfOp>(op), value, dim, cstr);
194  }
195 };
196 
197 } // namespace
198 } // namespace scf
199 } // namespace mlir
200 
202  DialectRegistry &registry) {
  // Register a dialect extension keyed on scf::SCFDialect whose callback
  // attaches the ValueBoundsOpInterface external models defined in this file
  // to scf.for, scf.forall and scf.if. The `dialect` parameter is unused.
  // NOTE(review): presumably the callback runs when the SCF dialect is loaded
  // into `ctx` — confirm against DialectRegistry::addExtension.
203  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
204  scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
205  scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
206  scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
207  });
208 }
Base type for affine expression.
Definition: AffineExpr.h:68
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Populate constraints for lhs/rhs (until the stop condition is met).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.