MLIR  16.0.0git
IntegerRangeAnalysis.cpp
Go to the documentation of this file.
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14 
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "int-range-analysis"
22 
23 using namespace mlir;
24 using namespace mlir::dataflow;
25 
// Build the widest range representable for `value`: the full unsigned range
// [0, umax] and the full signed range [smin, smax], both at the storage
// bitwidth ConstantIntRanges uses for the value's type.
27  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
28  APInt umin = APInt::getMinValue(width);
29  APInt umax = APInt::getMaxValue(width);
// A zero-width integer has no sign bit, so reuse the unsigned bounds instead
// of calling the signed-extreme helpers.
30  APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
31  APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
32  return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
33 }
34 
// Run the base Lattice update first so that users of the value are notified
// of the range change.
36  Lattice::onUpdate(solver);
37 
38  // If the integer range can be narrowed to a constant, update the constant
39  // value of the SSA value.
// NOTE(review): `constant` is declared on a line not visible here; presumably
// it holds the range's single constant value, if any - confirm.
41  auto value = point.get<Value>();
// The constant-propagation lattice attached to the same SSA value.
42  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
// Not narrowed to a single constant: record the "unknown constant" state.
43  if (!constant)
44  return solver->propagateIfChanged(
45  cv, cv->join(ConstantValue::getUnknownConstant()));
46 
// ConstantValue is built with a dialect. Take it from the defining op or, for
// a block argument, from the op that owns the argument's block.
47  Dialect *dialect;
48  if (auto *parent = value.getDefiningOp())
49  dialect = parent->getDialect();
50  else
51  dialect = value.getParentBlock()->getParentOp()->getDialect();
52  solver->propagateIfChanged(
53  cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
54  dialect)));
55 }
56 
60  // If the lattice on any operand is uninitialized, bail out.
61  if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
62  return lattice->getValue().isUninitialized();
63  })) {
64  return;
65  }
66 
67  // Ignore non-integer outputs - return early if the op has no scalar
68  // integer results.
69  bool hasIntegerResult = false;
70  for (auto it : llvm::zip(results, op->getResults())) {
71  Value value = std::get<1>(it);
72  if (value.getType().isIntOrIndex()) {
73  hasIntegerResult = true;
74  } else {
75  IntegerValueRangeLattice *lattice = std::get<0>(it);
76  propagateIfChanged(lattice,
78  }
79  }
80  if (!hasIntegerResult)
81  return;
82 
// Without InferIntRangeInterface there is nothing to infer from, so every
// result lattice is reset to its entry state.
83  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
84  if (!inferrable)
85  return setAllToEntryStates(results);
86 
// Gather the current range of every operand and hand them to the op's
// InferIntRangeInterface implementation below.
87  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
89  llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
90  return val->getValue().getValue();
91  }));
92 
// Invoked by inferResultRanges() for each value the op reports a range for.
// Values that are not results of `op` are ignored. For results, the reported
// range is joined into the corresponding result lattice.
93  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
94  auto result = v.dyn_cast<OpResult>();
95  if (!result)
96  return;
97  assert(llvm::is_contained(op->getResults(), result));
98 
99  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
100  IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
101  IntegerValueRange oldRange = lattice->getValue();
102 
103  ChangeResult changed = lattice->join(IntegerValueRange{attrs});
104 
105  // Catch loop results with loop variant bounds and conservatively make
106  // them [-inf, inf] so we don't circle around infinitely often (because
107  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
108  // and often can't).
// Concretely: a result that has at least one terminator among its users,
// and whose range changed after it was already initialized, is widened
// straight to the maximal range for its type.
109  bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
110  return op->hasTrait<OpTrait::IsTerminator>();
111  });
112  if (isYieldedResult && !oldRange.isUninitialized() &&
113  !(lattice->getValue() == oldRange)) {
114  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
115  changed |= lattice->join(IntegerValueRange::getMaxRange(v));
116  }
117  propagateIfChanged(lattice, changed);
118  };
119 
120  inferrable.inferResultRanges(argRanges, joinCallback);
121 }
122 
124  Operation *op, const RegionSuccessor &successor,
125  ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
// Ranges for the arguments of `successor` come from one of three sources:
//  1. the op's own InferIntRangeInterface implementation, if it has one;
//  2. for a LoopLikeOpInterface op with a single induction variable, the
//     loop's lower bound, upper bound and step;
//  3. otherwise the same arguments are forwarded onward (presumably to the
//     base-class visitNonControlFlowArguments() - confirm).
126  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
127  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
129  llvm::map_range(op->getOperands(), [&](Value value) {
130  return getLatticeElementFor(op, value)->getValue().getValue();
131  }));
132 
// Only block arguments of the successor region are updated. Any other value
// the interface reports a range for is ignored.
133  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
134  auto arg = v.dyn_cast<BlockArgument>();
135  if (!arg)
136  return;
137  if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
138  return;
139 
140  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
141  IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
142  IntegerValueRange oldRange = lattice->getValue();
143 
144  ChangeResult changed = lattice->join(IntegerValueRange{attrs});
145 
146  // Catch loop results with loop variant bounds and conservatively make
147  // them [-inf, inf] so we don't circle around infinitely often (because
148  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
149  // and often can't).
// Concretely: an argument that has at least one terminator among its users,
// and whose range changed after it was already initialized, is widened
// straight to the maximal range for its type.
150  bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
151  return op->hasTrait<OpTrait::IsTerminator>();
152  });
153  if (isYieldedValue && !oldRange.isUninitialized() &&
154  !(lattice->getValue() == oldRange)) {
155  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
156  changed |= lattice->join(IntegerValueRange::getMaxRange(v));
157  }
158  propagateIfChanged(lattice, changed);
159  };
160 
161  inferrable.inferResultRanges(argRanges, joinCallback);
162  return;
163  }
164 
165  /// Given the result of getSingle{Lower,Upper}Bound() or getSingleStep()
166  /// on a LoopLikeOpInterface, return the signed lower (getUpper=false) or
167  /// upper (getUpper=true) bound for that value. This is the value of a
/// constant IntegerAttr, or the smin/smax of the range lattice looked up for
/// an SSA value, or else the extreme signed value for `boundType`'s storage
/// bitwidth.
168  auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
169  Type boundType, bool getUpper) {
170  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
171  if (loopBound.has_value()) {
172  if (loopBound->is<Attribute>()) {
173  if (auto bound =
174  loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
175  return bound.getValue();
176  } else if (auto value = loopBound->dyn_cast<Value>()) {
177  const IntegerValueRangeLattice *lattice =
179  if (lattice != nullptr)
180  return getUpper ? lattice->getValue().getValue().smax()
181  : lattice->getValue().getValue().smin();
182  }
183  }
184  // No usable bound (absent, a non-integer attribute, or no lattice for the
185  // value): fall back to the widest signed bound of the type, which is the
186  // signed max for an upper bound and the signed min for a lower bound.
187  return getUpper ? APInt::getSignedMaxValue(width)
188  : APInt::getSignedMinValue(width);
189  };
190 
191  // Infer a range for the single induction variable of loop-like ops from
// the loop's lower bound, upper bound and step.
192  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
193  Optional<Value> iv = loop.getSingleInductionVar();
194  if (!iv) {
196  op, successor, argLattices, firstIndex);
197  }
198  Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
199  Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
200  Optional<OpFoldResult> step = loop.getSingleStep();
201  APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
202  /*getUpper=*/false);
203  APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
204  /*getUpper=*/true);
205  // Assume positivity for undiscoverable steps by way of getUpper = true.
206  APInt stepVal =
207  getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
208 
// For a negative step the bounds are swapped (presumably because the loop
// then counts down from `lowerBound` toward `upperBound`). In that case no
// off-by-one adjustment is made for the excluded bound, which only widens
// the range.
209  if (stepVal.isNegative()) {
210  std::swap(min, max);
211  } else {
212  // Correct the upper bound by subtracting 1 so that it becomes a <=
213  // bound, because loops do not generally include their upper bound.
// NOTE(review): APInt subtraction wraps, so an upper bound equal to the
// signed minimum of the type would become the signed maximum. Confirm that
// case cannot reach here.
214  max -= 1;
215  }
216 
// Join the resulting signed range into the induction variable's lattice
// entry (`ivEntry`).
218  auto ivRange = ConstantIntRanges::fromSigned(min, max);
219  propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
220  return;
221  }
222 
224  op, successor, argLattices, firstIndex);
225 }
static constexpr const bool value
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
ProgramPoint point
The program point to which the state belongs.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast_or_null() const
Definition: Attributes.h:132
This class represents an argument of a Block.
Definition: Value.h:296
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
Optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
The general data-flow analysis solver.
StateT * getOrCreateState(PointT point)
Get the state associated with the given program point.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
This is a value defined by a result of an operation.
Definition: Value.h:442
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
result_range getResults()
Definition: Operation.h:332
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
BlockArgListType getArguments()
Definition: Region.h:81
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
U dyn_cast() const
Definition: Value.h:95
user_range getUsers() const
Definition: Value.h:209
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
This lattice value represents a known constant value of a lattice.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
void visitOperation(Operation *op, ArrayRef< const IntegerValueRangeLattice * > operands, ArrayRef< IntegerValueRangeLattice * > results) override
Visit an operation.
void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< IntegerValueRangeLattice * > argLattices, unsigned firstIndex) override
Visit block arguments or operation results of an operation with region control-flow for which values ...
This lattice element represents the integer value range of an SSA value.
void onUpdate(DataFlowSolver *solver) const override
If the range can be narrowed to an integer constant, update the constant value of the SSA value.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) that is used to mark the v...
This class represents a lattice holding a specific value of type ValueT.
IntegerValueRange & getValue()
Return the value held by this lattice.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
void setAllToEntryStates(ArrayRef< IntegerValueRangeLattice * > lattices)
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< StateT * > argLattices, unsigned firstIndex)
Given an operation with possible region control-flow, the lattices of the operands,...
const IntegerValueRangeLattice * getLatticeElementFor(ProgramPoint point, Value value)
Get the lattice element for a value and create a dependency on the provided program point.
IntegerValueRangeLattice * getLatticeElement(Value value) override
Get the lattice element for a value.
Include the generated interface declarations.
ChangeResult
A result type used to indicate if a change happened.