MLIR  19.0.0git
IntegerRangeAnalysis.cpp
Go to the documentation of this file.
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
30 #include <cassert>
31 #include <optional>
32 #include <utility>
33 
34 #define DEBUG_TYPE "int-range-analysis"
35 
36 using namespace mlir;
37 using namespace mlir::dataflow;
38 
40  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
41  if (width == 0)
42  return {};
43  APInt umin = APInt::getMinValue(width);
44  APInt umax = APInt::getMaxValue(width);
45  APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
46  APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
47  return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
48 }
49 
51  Lattice::onUpdate(solver);
52 
53  // If the integer range can be narrowed to a constant, update the constant
54  // value of the SSA value.
55  std::optional<APInt> constant = getValue().getValue().getConstantValue();
56  auto value = point.get<Value>();
57  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
58  if (!constant)
59  return solver->propagateIfChanged(
60  cv, cv->join(ConstantValue::getUnknownConstant()));
61 
62  Dialect *dialect;
63  if (auto *parent = value.getDefiningOp())
64  dialect = parent->getDialect();
65  else
66  dialect = value.getParentBlock()->getParentOp()->getDialect();
67  solver->propagateIfChanged(
68  cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
69  dialect)));
70 }
71 
75  // If the lattice on any operand is unitialized, bail out.
76  if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
77  return lattice->getValue().isUninitialized();
78  })) {
79  return;
80  }
81 
82  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
83  if (!inferrable)
84  return setAllToEntryStates(results);
85 
86  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
88  llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
89  return val->getValue().getValue();
90  }));
91 
92  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
93  auto result = dyn_cast<OpResult>(v);
94  if (!result)
95  return;
96  assert(llvm::is_contained(op->getResults(), result));
97 
98  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
99  IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
100  IntegerValueRange oldRange = lattice->getValue();
101 
102  ChangeResult changed = lattice->join(IntegerValueRange{attrs});
103 
104  // Catch loop results with loop variant bounds and conservatively make
105  // them [-inf, inf] so we don't circle around infinitely often (because
106  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
107  // and often can't).
108  bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
109  return op->hasTrait<OpTrait::IsTerminator>();
110  });
111  if (isYieldedResult && !oldRange.isUninitialized() &&
112  !(lattice->getValue() == oldRange)) {
113  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
114  changed |= lattice->join(IntegerValueRange::getMaxRange(v));
115  }
116  propagateIfChanged(lattice, changed);
117  };
118 
119  inferrable.inferResultRanges(argRanges, joinCallback);
120 }
121 
123  Operation *op, const RegionSuccessor &successor,
124  ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
125  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
126  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
127  // If the lattice on any operand is unitialized, bail out.
128  if (llvm::any_of(op->getOperands(), [&](Value value) {
129  return getLatticeElementFor(op, value)->getValue().isUninitialized();
130  }))
131  return;
133  llvm::map_range(op->getOperands(), [&](Value value) {
134  return getLatticeElementFor(op, value)->getValue().getValue();
135  }));
136 
137  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
138  auto arg = dyn_cast<BlockArgument>(v);
139  if (!arg)
140  return;
141  if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
142  return;
143 
144  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
145  IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
146  IntegerValueRange oldRange = lattice->getValue();
147 
148  ChangeResult changed = lattice->join(IntegerValueRange{attrs});
149 
150  // Catch loop results with loop variant bounds and conservatively make
151  // them [-inf, inf] so we don't circle around infinitely often (because
152  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
153  // and often can't).
154  bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
155  return op->hasTrait<OpTrait::IsTerminator>();
156  });
157  if (isYieldedValue && !oldRange.isUninitialized() &&
158  !(lattice->getValue() == oldRange)) {
159  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
160  changed |= lattice->join(IntegerValueRange::getMaxRange(v));
161  }
162  propagateIfChanged(lattice, changed);
163  };
164 
165  inferrable.inferResultRanges(argRanges, joinCallback);
166  return;
167  }
168 
169  /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
170  /// on a LoopLikeInterface return the lower/upper bound for that result if
171  /// possible.
172  auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
173  Type boundType, bool getUpper) {
174  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
175  if (loopBound.has_value()) {
176  if (loopBound->is<Attribute>()) {
177  if (auto bound =
178  dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
179  return bound.getValue();
180  } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
181  const IntegerValueRangeLattice *lattice =
182  getLatticeElementFor(op, value);
183  if (lattice != nullptr && !lattice->getValue().isUninitialized())
184  return getUpper ? lattice->getValue().getValue().smax()
185  : lattice->getValue().getValue().smin();
186  }
187  }
188  // Given the results of getConstant{Lower,Upper}Bound()
189  // or getConstantStep() on a LoopLikeInterface return the lower/upper
190  // bound
191  return getUpper ? APInt::getSignedMaxValue(width)
192  : APInt::getSignedMinValue(width);
193  };
194 
195  // Infer bounds for loop arguments that have static bounds
196  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
197  std::optional<Value> iv = loop.getSingleInductionVar();
198  if (!iv) {
200  op, successor, argLattices, firstIndex);
201  }
202  std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
203  std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
204  std::optional<OpFoldResult> step = loop.getSingleStep();
205  APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
206  /*getUpper=*/false);
207  APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
208  /*getUpper=*/true);
209  // Assume positivity for uniscoverable steps by way of getUpper = true.
210  APInt stepVal =
211  getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
212 
213  if (stepVal.isNegative()) {
214  std::swap(min, max);
215  } else {
216  // Correct the upper bound by subtracting 1 so that it becomes a <=
217  // bound, because loops do not generally include their upper bound.
218  max -= 1;
219  }
220 
222  auto ivRange = ConstantIntRanges::fromSigned(min, max);
223  propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
224  return;
225  }
226 
228  op, successor, argLattices, firstIndex);
229 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
ProgramPoint point
The program point to which the state belongs.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
The general data-flow analysis solver.
StateT * getOrCreateState(PointT point)
Get the state associated with the given program point.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
BlockArgListType getArguments()
Definition: Region.h:81
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
This lattice value represents a known constant value of a lattice.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
void visitOperation(Operation *op, ArrayRef< const IntegerValueRangeLattice * > operands, ArrayRef< IntegerValueRangeLattice * > results) override
Visit an operation.
void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< IntegerValueRangeLattice * > argLattices, unsigned firstIndex) override
Visit block arguments or operation results of an operation with region control-flow for which values ...
This lattice element represents the integer value range of an SSA value.
void onUpdate(DataFlowSolver *solver) const override
If the range can be narrowed to an integer constant, update the constant value of the SSA value.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) that is used to mark the v...
This class represents a lattice holding a specific value of type ValueT.
IntegerValueRange & getValue()
Return the value held by this lattice.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
const IntegerValueRangeLattice * getLatticeElementFor(ProgramPoint point, Value value)
Get the lattice element for a value and create a dependency on the provided program point.
IntegerValueRangeLattice * getLatticeElement(Value value) override
Get the lattice element for a value.
void setAllToEntryStates(ArrayRef< IntegerValueRangeLattice * > lattices)
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< StateT * > argLattices, unsigned firstIndex)
Given an operation with possible region control-flow, the lattices of the operands,...
Include the generated interface declarations.
ChangeResult
A result type used to indicate if a change happened.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...