MLIR  22.0.0git
IntegerRangeAnalysis.cpp
Go to the documentation of this file.
1 //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the dataflow analysis class for integer range inference
10 // which is used in transformations over the `arith` dialect such as
11 // branch elimination or signed->unsigned rewriting
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Value.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/DebugLog.h"
35 #include <cassert>
36 #include <optional>
37 #include <utility>
38 
39 #define DEBUG_TYPE "int-range-analysis"
40 
41 using namespace mlir;
42 using namespace mlir::dataflow;
43 
44 namespace mlir::dataflow {
45 LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
46  auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
47  if (!result || result->getValue().isUninitialized())
48  return failure();
49  const ConstantIntRanges &range = result->getValue().getValue();
50  return success(range.smin().isNonNegative());
51 }
52 
53 LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) {
54  auto nonNegativePred = [&solver](Value v) -> bool {
55  return succeeded(staticallyNonNegative(solver, v));
56  };
57  return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
58  llvm::all_of(op->getResults(), nonNegativePred));
59 }
60 } // namespace mlir::dataflow
61 
// Body of the lattice update hook. NOTE(review): the signature (listing line
// 62) is missing from this listing; the body only relies on `solver`, `anchor`
// and `getValue()`.
//
// After running the base-class update, this mirrors the integer range into the
// `Lattice<ConstantValue>` state of the same SSA value: an unknown constant
// when the range does not collapse to one value, otherwise an attribute built
// from that single value.
63  Lattice::onUpdate(solver);
64 
65  // If the integer range can be narrowed to a constant, update the constant
66  // value of the SSA value.
67  std::optional<APInt> constant = getValue().getValue().getConstantValue();
68  auto value = cast<Value>(anchor);
69  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
// The range is not a single value: join in the "unknown constant" state and
// stop here.
70  if (!constant)
71  return solver->propagateIfChanged(
72  cv, cv->join(ConstantValue::getUnknownConstant()));
73 
// Choose the dialect recorded with the constant: the defining op's dialect,
// or, for values without a defining op, the dialect of the op that owns the
// value's parent block.
// NOTE(review): getParentBlock()/getParentOp() are dereferenced unchecked.
74  Dialect *dialect;
75  if (auto *parent = value.getDefiningOp())
76  dialect = parent->getDialect();
77  else
78  dialect = value.getParentBlock()->getParentOp()->getDialect();
79 
// Build the constant attribute: an IntegerAttr for integer/index types, a
// splat for shaped types; any other type is a fatal error.
80  Attribute cstAttr;
81  if (isa<IntegerType, IndexType>(value.getType())) {
82  cstAttr = IntegerAttr::get(value.getType(), *constant);
83  } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
84  cstAttr = SplatElementsAttr::get(shapedTy, *constant);
85  } else {
86  llvm::report_fatal_error(
87  Twine("FIXME: Don't know how to create a constant for this type: ") +
88  mlir::debugString(value.getType()));
89  }
90  solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
91 }
92 
// Body of the per-operation transfer function. NOTE(review): the signature
// (listing lines 93-95) is missing from this listing; the body uses `op`,
// `operands` (input lattices) and `results` (output lattices).
//
// Ops that do not implement InferIntRangeInterface get all their result
// lattices reset to the entry state. Otherwise the operand ranges are handed
// to the interface, which reports inferred ranges through `joinCallback`.
96  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
97  if (!inferrable) {
98  setAllToEntryStates(results);
99  return success();
100  }
101 
102  LDBG() << "Inferring ranges for "
103  << OpWithFlags(op, OpPrintingFlags().skipRegions());
// Snapshot the current range of every operand lattice.
104  auto argRanges = llvm::map_to_vector(
105  operands, [](const IntegerValueRangeLattice *lattice) {
106  return lattice->getValue();
107  });
108 
// Called by the interface for each value it inferred a range for. Only
// results of `op` are handled here; other values are ignored.
109  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
110  auto result = dyn_cast<OpResult>(v);
111  if (!result)
112  return;
113  assert(llvm::is_contained(op->getResults(), result));
114 
115  LDBG() << "Inferred range " << attrs;
116  IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
// Keep the pre-join range so a change caused by this join can be detected.
117  IntegerValueRange oldRange = lattice->getValue();
118 
119  ChangeResult changed = lattice->join(attrs);
120 
121  // Catch loop results with loop variant bounds and conservatively make
122  // them [-inf, inf] so we don't circle around infinitely often (because
123  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
124  // and often can't).
// "Yielded" is approximated as: some user of the value is a terminator op.
125  bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
126  return op->hasTrait<OpTrait::IsTerminator>();
127  });
128  if (isYieldedResult && !oldRange.isUninitialized() &&
129  !(lattice->getValue() == oldRange)) {
130  LDBG() << "Loop variant loop result detected";
// NOTE(review): listing line 131 is absent here. As shown, this branch only
// logs; the widening described in the comment above is not visible. Confirm
// against the upstream file.
132  }
133  propagateIfChanged(lattice, changed);
134  };
135 
136  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
137  return success();
138 }
139 
// Parameter list and body of the handler for region arguments that are not
// driven by control flow. NOTE(review): the line naming the function (listing
// line 140) is missing from this listing.
//
// Two cases are handled here:
//  1. `op` implements InferIntRangeInterface: ranges reported for block
//     arguments of the successor region are joined into `argLattices`.
//  2. `op` implements LoopLikeOpInterface: a signed range for the single
//     induction variable is derived from the loop's lower bound, upper bound
//     and step.
// The remaining paths pass (op, successor, argLattices, firstIndex) to a call
// whose callee line is missing from this listing.
141  Operation *op, const RegionSuccessor &successor,
142  ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
143  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
144  LDBG() << "Inferring ranges for "
145  << OpWithFlags(op, OpPrintingFlags().skipRegions());
146 
// Read each operand's range via getLatticeElementFor, using the program point
// after `op`.
147  auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
148  return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
149  });
150 
// Called by the interface for each inferred range. Only block arguments that
// belong to the successor region are handled; everything else is ignored.
151  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
152  auto arg = dyn_cast<BlockArgument>(v);
153  if (!arg)
154  return;
155  if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
156  return;
157 
158  LDBG() << "Inferred range " << attrs;
// NOTE(review): indexed by the argument number; `firstIndex` is not applied
// here. Confirm that argLattices is aligned with argument 0.
159  IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
160  IntegerValueRange oldRange = lattice->getValue();
161 
162  ChangeResult changed = lattice->join(attrs);
163 
164  // Catch loop results with loop variant bounds and conservatively make
165  // them [-inf, inf] so we don't circle around infinitely often (because
166  // the dataflow analysis in MLIR doesn't attempt to work out trip counts
167  // and often can't).
168  bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
169  return op->hasTrait<OpTrait::IsTerminator>();
170  });
171  if (isYieldedValue && !oldRange.isUninitialized() &&
172  !(lattice->getValue() == oldRange)) {
173  LDBG() << "Loop variant loop result detected";
// NOTE(review): listing line 174 is absent here. As shown, this branch only
// logs; the widening described in the comment above is not visible.
175  }
176  propagateIfChanged(lattice, changed);
177  };
178 
179  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
180  return;
181  }
182 
183  /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
184  /// on a LoopLikeInterface return the lower/upper bound for that result if
185  /// possible.
// `getUpper` selects which side is returned: smax / signed-max when true,
// smin / signed-min when false.
186  auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
187  Type boundType, Block *block, bool getUpper) {
188  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
189  if (loopBound.has_value()) {
190  if (auto attr = dyn_cast<Attribute>(*loopBound)) {
// A constant integer attribute is used as-is.
191  if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
192  return bound.getValue();
193  } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
194  const IntegerValueRangeLattice *lattice =
// NOTE(review): listing line 195, the initializer of `lattice`, is absent
// here; presumably that is where `block` and `value` are used. Confirm.
196  if (lattice != nullptr && !lattice->getValue().isUninitialized())
197  return getUpper ? lattice->getValue().getValue().smax()
198  : lattice->getValue().getValue().smin();
199  }
200  }
201  // Nothing better is known (no bound given, an attribute that is not an
202  // IntegerAttr, or a missing/uninitialized lattice): fall back to the
203  // widest signed value for the requested side.
204  return getUpper ? APInt::getSignedMaxValue(width)
205  : APInt::getSignedMinValue(width);
206  };
207 
208  // Infer bounds for loop arguments that have static bounds
209  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
210  std::optional<Value> iv = loop.getSingleInductionVar();
211  if (!iv) {
// NOTE(review): listing line 212 is absent here; only the argument list of
// the call survives on the next line.
213  op, successor, argLattices, firstIndex);
214  }
215  Block *block = iv->getParentBlock();
216  std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
217  std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
218  std::optional<OpFoldResult> step = loop.getSingleStep();
219  APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
220  /*getUpper=*/false);
221  APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
222  /*getUpper=*/true);
223  // Assume positivity for undiscoverable steps by way of getUpper = true.
224  APInt stepVal =
225  getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
226 
// With a negative step the bounds are swapped so that `min` <= `max` can hold.
227  if (stepVal.isNegative()) {
228  std::swap(min, max);
229  } else {
230  // Correct the upper bound by subtracting 1 so that it becomes a <=
231  // bound, because loops do not generally include their upper bound.
// NOTE(review): if `max` is the signed minimum this subtraction wraps around;
// confirm whether that case can occur.
232  max -= 1;
233  }
234 
235  // If we infer the lower bound to be larger than the upper bound, the
236  // resulting range is meaningless and should not be used in further
237  // inferences.
238  if (max.sge(min)) {
// NOTE(review): listing line 239, which declares `ivEntry`, is absent here.
240  auto ivRange = ConstantIntRanges::fromSigned(min, max);
241  propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
242  }
243  return;
244  }
245 
// NOTE(review): listing line 246 is absent here; only the argument list of
// the call survives on the next line.
247  op, successor, argLattices, firstIndex);
248 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
LatticeAnchor anchor
The lattice anchor to which the state belongs.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
The general data-flow analysis solver.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of...
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) that is used to mark the v...
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
BlockArgListType getArguments()
Definition: Region.h:81
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
user_range getUsers() const
Definition: Value.h:218
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
This lattice value represents a known constant value of a lattice.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
LogicalResult visitOperation(Operation *op, ArrayRef< const IntegerValueRangeLattice * > operands, ArrayRef< IntegerValueRangeLattice * > results) override
Visit an operation.
void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< IntegerValueRangeLattice * > argLattices, unsigned firstIndex) override
Visit block arguments or operation results of an operation with region control-flow for which values ...
This lattice element represents the integer value range of an SSA value.
void onUpdate(DataFlowSolver *solver) const override
If the range can be narrowed to an integer constant, update the constant value of the SSA value.
This class represents a lattice holding a specific value of type ValueT.
IntegerValueRange & getValue()
Return the value held by this lattice.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
const IntegerValueRangeLattice * getLatticeElementFor(ProgramPoint *point, Value value)
Get the lattice element for a value and create a dependency on the provided program point.
IntegerValueRangeLattice * getLatticeElement(Value value) override
Get the lattice element for a value.
void setAllToEntryStates(ArrayRef< IntegerValueRangeLattice * > lattices)
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< StateT * > argLattices, unsigned firstIndex)
Given an operation with possible region control-flow, the lattices of the operands,...
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op)
Succeeds if an op can be converted to its unsigned equivalent without changing its semantics.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
ChangeResult
A result type used to indicate if a change happened.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...