MLIR  18.0.0git
ValueBoundsOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 
14 using namespace mlir;
16 
17 namespace mlir {
18 namespace scf {
19 namespace {
20 
21 struct ForOpInterface
22  : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
23 
24  /// Populate bounds of values/dimensions for iter_args/OpResults.
25  static void populateIterArgBounds(scf::ForOp forOp, Value value,
26  std::optional<int64_t> dim,
28  // `value` is an iter_arg or an OpResult.
29  int64_t iterArgIdx;
30  if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
31  iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
32  } else {
33  iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
34  }
35 
36  // An EQ constraint can be added if the yielded value (dimension size)
37  // equals the corresponding block argument (dimension size).
38  Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
39  .getOperand(iterArgIdx);
40  Value iterArg = forOp.getRegionIterArg(iterArgIdx);
41  Value initArg = forOp.getInitArgs()[iterArgIdx];
42 
43  auto addEqBound = [&]() {
44  if (dim.has_value()) {
45  cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
46  } else {
47  cstr.bound(value) == initArg;
48  }
49  };
50 
51  if (yieldedValue == iterArg) {
52  addEqBound();
53  return;
54  }
55 
56  // Compute EQ bound for yielded value.
57  AffineMap bound;
58  ValueDimList boundOperands;
60  bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61  [&](Value v, std::optional<int64_t> d) {
62  // Stop when reaching a block argument of the loop body.
63  if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
64  return bbArg.getOwner()->getParentOp() == forOp;
65  // Stop when reaching a value that is defined outside of the loop. It
66  // is impossible to reach an iter_arg from there.
67  Operation *op = v.getDefiningOp();
68  return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
69  });
70  if (failed(status))
71  return;
72  if (bound.getNumResults() != 1)
73  return;
74 
75  // Check if computed bound equals the corresponding iter_arg.
76  Value singleValue = nullptr;
77  std::optional<int64_t> singleDim;
78  if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) {
79  int64_t idx = dimExpr.getPosition();
80  singleValue = boundOperands[idx].first;
81  singleDim = boundOperands[idx].second;
82  } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) {
83  int64_t idx = symExpr.getPosition() + bound.getNumDims();
84  singleValue = boundOperands[idx].first;
85  singleDim = boundOperands[idx].second;
86  }
87  if (singleValue == iterArg && singleDim == dim)
88  addEqBound();
89  }
90 
91  void populateBoundsForIndexValue(Operation *op, Value value,
92  ValueBoundsConstraintSet &cstr) const {
93  auto forOp = cast<ForOp>(op);
94 
95  if (value == forOp.getInductionVar()) {
96  // TODO: Take into account step size.
97  cstr.bound(value) >= forOp.getLowerBound();
98  cstr.bound(value) < forOp.getUpperBound();
99  return;
100  }
101 
102  // Handle iter_args and OpResults.
103  populateIterArgBounds(forOp, value, std::nullopt, cstr);
104  }
105 
106  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
107  ValueBoundsConstraintSet &cstr) const {
108  auto forOp = cast<ForOp>(op);
109  // Handle iter_args and OpResults.
110  populateIterArgBounds(forOp, value, dim, cstr);
111  }
112 };
113 
114 } // namespace
115 } // namespace scf
116 } // namespace mlir
117 
119  DialectRegistry &registry) {
120  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
121  scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
122  });
123 }
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
unsigned getNumDims() const
Definition: AffineMap.cpp:374
unsigned getNumResults() const
Definition: AffineMap.cpp:382
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:391
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
Operation * findAncestorOpInRegion(Operation &op)
Returns 'op' if 'op' lies in this region, or otherwise finds the ancestor of 'op' that lies in this region. Returns nullptr if no such ancestor exists.
Definition: Region.cpp:168
A helper class to be used with ValueBoundsOpInterface.
AffineExpr getExpr(Value value, std::optional< int64_t > dim=std::nullopt)
Return an expression that represents the given index-typed value or shaped value dimension.
static LogicalResult computeBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, Value value, std::optional< int64_t > dim, StopConditionFn stopCondition, bool closedUB=false)
Compute a bound for the given index-typed value or shape dimension size.
BoundBuilder bound(Value value)
Add a bound for the given index-typed value or shaped value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
BoundType
The type of bound: equal, lower bound or upper bound.
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26