MLIR  20.0.0git
IntegerRangeAnalysis.cpp
Go to the documentation of this file.
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
30 #include <cassert>
31 #include <optional>
32 #include <utility>
33 
34 #define DEBUG_TYPE "int-range-analysis"
35 
36 using namespace mlir;
37 using namespace mlir::dataflow;
38 
// NOTE(review): the signature line of this definition is absent from this
// listing (numbering jumps from 38 to 40); presumably this is the body of
// IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const -- confirm
// against the real source file.
//
// Purpose: after the base-class update hook runs, mirror the integer range
// into the ConstantValue lattice of the same SSA value.
40  Lattice::onUpdate(solver);
41 
42  // If the integer range can be narrowed to a constant, update the constant
43  // value of the SSA value.
44  std::optional<APInt> constant = getValue().getValue().getConstantValue();
45  auto value = anchor.get<Value>();
46  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
    // The range does not collapse to a single constant: join the
    // "unknown constant" state into the constant lattice and stop here.
47  if (!constant)
48  return solver->propagateIfChanged(
49  cv, cv->join(ConstantValue::getUnknownConstant()));
50 
    // Choose the dialect passed to the ConstantValue: the defining op's
    // dialect when the value has a defining op, otherwise the dialect of the
    // op that owns the value's parent block.
51  Dialect *dialect;
52  if (auto *parent = value.getDefiningOp())
53  dialect = parent->getDialect();
54  else
55  dialect = value.getParentBlock()->getParentOp()->getDialect();
    // Join the constant, wrapped as an IntegerAttr of the value's type, into
    // the constant lattice and propagate if that changed it.
56  solver->propagateIfChanged(
57  cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
58  dialect)));
59 }
60 
// NOTE(review): the signature lines of this definition are absent from this
// listing (numbering jumps from 60 to 64); presumably this is the body of
// IntegerRangeAnalysis::visitOperation(op, operands, results) -- confirm
// against the real source file.
//
// Purpose: infer result ranges for `op` through InferIntRangeInterface and
// join them into the result lattices. Returns success() on every path shown.
64  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
    // Ops that do not implement the interface cannot be inferred: set all
    // result lattices to their entry state.
65  if (!inferrable) {
66  setAllToEntryStates(results);
67  return success();
68  }
69 
70  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
    // Collect the current range held by each operand lattice.
71  auto argRanges = llvm::map_to_vector(
72  operands, [](const IntegerValueRangeLattice *lattice) {
73  return lattice->getValue();
74  });
75 
    // Callback handed to the interface: called with a value and the range
    // inferred for it. Values that are not OpResults are ignored.
76  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
77  auto result = dyn_cast<OpResult>(v);
78  if (!result)
79  return;
80  assert(llvm::is_contained(op->getResults(), result));
81 
82  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
83  IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
    // Keep the pre-join range so a change can be detected below.
84  IntegerValueRange oldRange = lattice->getValue();
85 
86  ChangeResult changed = lattice->join(attrs);
87 
88  // Catch loop results with loop variant bounds and conservatively make
89  // them [-inf, inf] so we don't circle around infinitely often (because
90  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
91  // and often can't).
    // A result is treated as yielded when any of its users is a terminator.
92  bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
93  return op->hasTrait<OpTrait::IsTerminator>();
94  });
    // Widen to the maximal range only if the lattice was already initialized
    // and this join actually changed its value.
95  if (isYieldedResult && !oldRange.isUninitialized() &&
96  !(lattice->getValue() == oldRange)) {
97  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
98  changed |= lattice->join(IntegerValueRange::getMaxRange(v));
99  }
100  propagateIfChanged(lattice, changed);
101  };
102 
103  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
104  return success();
105 }
106 
// NOTE(review): the first line of this signature is absent from this listing
// (numbering jumps from 106 to 108); presumably this is
// IntegerRangeAnalysis::visitNonControlFlowArguments -- confirm against the
// real source file.
//
// Purpose: infer ranges for block arguments of `successor` that are not
// covered by region control flow. Ops implementing InferIntRangeInterface are
// asked directly; for LoopLikeOpInterface ops the induction variable is
// bounded from the loop's lower bound, upper bound and step.
108  Operation *op, const RegionSuccessor &successor,
109  ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
110  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
111  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
112 
    // Collect the current range of every operand of `op`.
113  auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
114  return getLatticeElementFor(op, value)->getValue();
115  });
116 
    // Callback handed to the interface. Only block arguments that belong to
    // the successor region are handled; every other value is ignored.
117  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
118  auto arg = dyn_cast<BlockArgument>(v);
119  if (!arg)
120  return;
121  if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
122  return;
123 
124  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
125  IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
    // Keep the pre-join range so a change can be detected below.
126  IntegerValueRange oldRange = lattice->getValue();
127 
128  ChangeResult changed = lattice->join(attrs);
129 
130  // Catch loop results with loop variant bounds and conservatively make
131  // them [-inf, inf] so we don't circle around infinitely often (because
132  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
133  // and often can't).
    // A value is treated as yielded when any of its users is a terminator.
134  bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
135  return op->hasTrait<OpTrait::IsTerminator>();
136  });
    // Widen to the maximal range only if the lattice was already
    // initialized and this join actually changed its value.
137  if (isYieldedValue && !oldRange.isUninitialized() &&
138  !(lattice->getValue() == oldRange)) {
139  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
140  changed |= lattice->join(IntegerValueRange::getMaxRange(v));
141  }
142  propagateIfChanged(lattice, changed);
143  };
144 
145  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
146  return;
147  }
148 
149  /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
150  /// on a LoopLikeInterface return the lower/upper bound for that result if
151  /// possible.
    /// An IntegerAttr bound yields its value directly. A Value bound yields
    /// the signed max (getUpper) or signed min of its lattice range, provided
    /// the lattice exists and is initialized.
152  auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
153  Type boundType, bool getUpper) {
154  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
155  if (loopBound.has_value()) {
156  if (loopBound->is<Attribute>()) {
157  if (auto bound =
158  dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
159  return bound.getValue();
160  } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
161  const IntegerValueRangeLattice *lattice =
162  getLatticeElementFor(op, value);
163  if (lattice != nullptr && !lattice->getValue().isUninitialized())
164  return getUpper ? lattice->getValue().getValue().smax()
165  : lattice->getValue().getValue().smin();
166  }
167  }
168  // No usable bound was found: fall back to the extreme signed value of the
169  // bound's storage width (signed max for an upper bound, signed min for a
170  // lower bound).
171  return getUpper ? APInt::getSignedMaxValue(width)
172  : APInt::getSignedMinValue(width);
173  };
174 
175  // Infer bounds for loop arguments that have static bounds
176  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
177  std::optional<Value> iv = loop.getSingleInductionVar();
178  if (!iv) {
    // NOTE(review): the head of this statement (listing line 179) is missing
    // from this listing; presumably it returns the result of the base-class
    // visitNonControlFlowArguments -- confirm against the real source file.
180  op, successor, argLattices, firstIndex);
181  }
182  std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
183  std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
184  std::optional<OpFoldResult> step = loop.getSingleStep();
185  APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
186  /*getUpper=*/false);
187  APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
188  /*getUpper=*/true);
189  // Assume positivity for undiscoverable steps by way of getUpper = true.
190  APInt stepVal =
191  getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
192 
    // With a negative step the two bounds are swapped and the upper bound is
    // left as is (presumably because the loop then counts downward).
193  if (stepVal.isNegative()) {
194  std::swap(min, max);
195  } else {
196  // Correct the upper bound by subtracting 1 so that it becomes a <=
197  // bound, because loops do not generally include their upper bound.
198  max -= 1;
199  }
200 
201  // If we infer the lower bound to be larger than the upper bound, the
202  // resulting range is meaningless and should not be used in further
203  // inferences.
204  if (max.sge(min)) {
    // NOTE(review): the declaration of `ivEntry` (listing line 205) is missing
    // from this listing; presumably it is the lattice element of the
    // induction variable -- confirm against the real source file.
206  auto ivRange = ConstantIntRanges::fromSigned(min, max);
207  propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
208  }
209  return;
210  }
211 
    // NOTE(review): the head of this statement (listing line 212) is missing
    // from this listing; presumably it is a call to the base-class
    // visitNonControlFlowArguments -- confirm against the real source file.
213  op, successor, argLattices, firstIndex);
214 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
LatticeAnchor anchor
The lattice anchor to which the state belongs.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
The general data-flow analysis solver.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of...
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) that is used to mark the v...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
BlockArgListType getArguments()
Definition: Region.h:81
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
This lattice value represents a known constant value of a lattice.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
LogicalResult visitOperation(Operation *op, ArrayRef< const IntegerValueRangeLattice * > operands, ArrayRef< IntegerValueRangeLattice * > results) override
Visit an operation.
void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< IntegerValueRangeLattice * > argLattices, unsigned firstIndex) override
Visit block arguments or operation results of an operation with region control-flow for which values ...
This lattice element represents the integer value range of an SSA value.
void onUpdate(DataFlowSolver *solver) const override
If the range can be narrowed to an integer constant, update the constant value of the SSA value.
This class represents a lattice holding a specific value of type ValueT.
IntegerValueRange & getValue()
Return the value held by this lattice.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
const IntegerValueRangeLattice * getLatticeElementFor(ProgramPoint point, Value value)
Get the lattice element for a value and create a dependency on the provided program point.
IntegerValueRangeLattice * getLatticeElement(Value value) override
Get the lattice element for a value.
void setAllToEntryStates(ArrayRef< IntegerValueRangeLattice * > lattices)
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< StateT * > argLattices, unsigned firstIndex)
Given an operation with possible region control-flow, the lattices of the operands,...
Include the generated interface declarations.
ChangeResult
A result type used to indicate if a change happened.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...