MLIR  20.0.0git
IntegerRangeAnalysis.cpp
Go to the documentation of this file.
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/Debug.h"
31 #include <cassert>
32 #include <optional>
33 #include <utility>
34 
35 #define DEBUG_TYPE "int-range-analysis"
36 
37 using namespace mlir;
38 using namespace mlir::dataflow;
39 
41  Lattice::onUpdate(solver);
42 
43  // If the integer range can be narrowed to a constant, update the constant
44  // value of the SSA value.
45  std::optional<APInt> constant = getValue().getValue().getConstantValue();
46  auto value = cast<Value>(anchor);
47  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
48  if (!constant)
49  return solver->propagateIfChanged(
50  cv, cv->join(ConstantValue::getUnknownConstant()));
51 
52  Dialect *dialect;
53  if (auto *parent = value.getDefiningOp())
54  dialect = parent->getDialect();
55  else
56  dialect = value.getParentBlock()->getParentOp()->getDialect();
57 
58  Type type = getElementTypeOrSelf(value);
59  solver->propagateIfChanged(
60  cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
61 }
62 
66  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
67  if (!inferrable) {
68  setAllToEntryStates(results);
69  return success();
70  }
71 
72  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
73  auto argRanges = llvm::map_to_vector(
74  operands, [](const IntegerValueRangeLattice *lattice) {
75  return lattice->getValue();
76  });
77 
78  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
79  auto result = dyn_cast<OpResult>(v);
80  if (!result)
81  return;
82  assert(llvm::is_contained(op->getResults(), result));
83 
84  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
85  IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
86  IntegerValueRange oldRange = lattice->getValue();
87 
88  ChangeResult changed = lattice->join(attrs);
89 
90  // Catch loop results with loop variant bounds and conservatively make
91  // them [-inf, inf] so we don't circle around infinitely often (because
92  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
93  // and often can't).
94  bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
95  return op->hasTrait<OpTrait::IsTerminator>();
96  });
97  if (isYieldedResult && !oldRange.isUninitialized() &&
98  !(lattice->getValue() == oldRange)) {
99  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
101  }
102  propagateIfChanged(lattice, changed);
103  };
104 
105  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
106  return success();
107 }
108 
110  Operation *op, const RegionSuccessor &successor,
111  ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
112  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
113  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
114 
115  auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
116  return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
117  });
118 
119  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
120  auto arg = dyn_cast<BlockArgument>(v);
121  if (!arg)
122  return;
123  if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
124  return;
125 
126  LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
127  IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
128  IntegerValueRange oldRange = lattice->getValue();
129 
130  ChangeResult changed = lattice->join(attrs);
131 
132  // Catch loop results with loop variant bounds and conservatively make
133  // them [-inf, inf] so we don't circle around infinitely often (because
134  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
135  // and often can't).
136  bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
137  return op->hasTrait<OpTrait::IsTerminator>();
138  });
139  if (isYieldedValue && !oldRange.isUninitialized() &&
140  !(lattice->getValue() == oldRange)) {
141  LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
143  }
144  propagateIfChanged(lattice, changed);
145  };
146 
147  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
148  return;
149  }
150 
151  /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
152  /// on a LoopLikeInterface return the lower/upper bound for that result if
153  /// possible.
154  auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
155  Type boundType, bool getUpper) {
156  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
157  if (loopBound.has_value()) {
158  if (auto attr = dyn_cast<Attribute>(*loopBound)) {
159  if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
160  return bound.getValue();
161  } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
162  const IntegerValueRangeLattice *lattice =
164  if (lattice != nullptr && !lattice->getValue().isUninitialized())
165  return getUpper ? lattice->getValue().getValue().smax()
166  : lattice->getValue().getValue().smin();
167  }
168  }
169  // Given the results of getConstant{Lower,Upper}Bound()
170  // or getConstantStep() on a LoopLikeInterface return the lower/upper
171  // bound
172  return getUpper ? APInt::getSignedMaxValue(width)
173  : APInt::getSignedMinValue(width);
174  };
175 
176  // Infer bounds for loop arguments that have static bounds
177  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
178  std::optional<Value> iv = loop.getSingleInductionVar();
179  if (!iv) {
181  op, successor, argLattices, firstIndex);
182  }
183  std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
184  std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
185  std::optional<OpFoldResult> step = loop.getSingleStep();
186  APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
187  /*getUpper=*/false);
188  APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
189  /*getUpper=*/true);
190  // Assume positivity for uniscoverable steps by way of getUpper = true.
191  APInt stepVal =
192  getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
193 
194  if (stepVal.isNegative()) {
195  std::swap(min, max);
196  } else {
197  // Correct the upper bound by subtracting 1 so that it becomes a <=
198  // bound, because loops do not generally include their upper bound.
199  max -= 1;
200  }
201 
202  // If we infer the lower bound to be larger than the upper bound, the
203  // resulting range is meaningless and should not be used in further
204  // inferences.
205  if (max.sge(min)) {
207  auto ivRange = ConstantIntRanges::fromSigned(min, max);
208  propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
209  }
210  return;
211  }
212 
214  op, successor, argLattices, firstIndex);
215 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
LatticeAnchor anchor
The lattice anchor to which the state belongs.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
ProgramPoint * getProgramPointAfter(Operation *op)
The general data-flow analysis solver.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of...
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) that is used to mark the v...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
BlockArgListType getArguments()
Definition: Region.h:81
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:228
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
This lattice value represents a known constant value of a lattice.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
LogicalResult visitOperation(Operation *op, ArrayRef< const IntegerValueRangeLattice * > operands, ArrayRef< IntegerValueRangeLattice * > results) override
Visit an operation.
void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< IntegerValueRangeLattice * > argLattices, unsigned firstIndex) override
Visit block arguments or operation results of an operation with region control-flow for which values ...
This lattice element represents the integer value range of an SSA value.
void onUpdate(DataFlowSolver *solver) const override
If the range can be narrowed to an integer constant, update the constant value of the SSA value.
This class represents a lattice holding a specific value of type ValueT.
IntegerValueRange & getValue()
Return the value held by this lattice.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
const IntegerValueRangeLattice * getLatticeElementFor(ProgramPoint *point, Value value)
Get the lattice element for a value and create a dependency on the provided program point.
IntegerValueRangeLattice * getLatticeElement(Value value) override
Get the lattice element for a value.
void setAllToEntryStates(ArrayRef< IntegerValueRangeLattice * > lattices)
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< StateT * > argLattices, unsigned firstIndex)
Given an operation with possible region control-flow, the lattices of the operands,...
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
ChangeResult
A result type used to indicate if a change happened.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...