MLIR 22.0.0git
IntegerRangeAnalysis.cpp
Go to the documentation of this file.
//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the dataflow analysis class for integer range inference
// which is used in transformations over the `arith` dialect such as
// branch elimination or signed->unsigned rewriting.
//
//===----------------------------------------------------------------------===//
14
20#include "mlir/IR/Dialect.h"
22#include "mlir/IR/Operation.h"
25#include "mlir/IR/Value.h"
30#include "mlir/Support/LLVM.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/Support/Casting.h"
33#include "llvm/Support/Debug.h"
34#include "llvm/Support/DebugLog.h"
35#include <cassert>
36#include <optional>
37#include <utility>
38
39#define DEBUG_TYPE "int-range-analysis"
40
41using namespace mlir;
42using namespace mlir::dataflow;
43
44namespace mlir::dataflow {
45LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
47 if (!result || result->getValue().isUninitialized())
48 return failure();
49 const ConstantIntRanges &range = result->getValue().getValue();
50 return success(range.smin().isNonNegative());
51}
52
53LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) {
54 auto nonNegativePred = [&solver](Value v) -> bool {
55 return succeeded(staticallyNonNegative(solver, v));
56 };
57 return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
58 llvm::all_of(op->getResults(), nonNegativePred));
59}
60} // namespace mlir::dataflow
61
63 Lattice::onUpdate(solver);
64
65 // If the integer range can be narrowed to a constant, update the constant
66 // value of the SSA value.
67 std::optional<APInt> constant = getValue().getValue().getConstantValue();
68 auto value = cast<Value>(anchor);
69 auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
70 if (!constant)
71 return solver->propagateIfChanged(
72 cv, cv->join(ConstantValue::getUnknownConstant()));
73
74 Dialect *dialect;
75 if (auto *parent = value.getDefiningOp())
76 dialect = parent->getDialect();
77 else
78 dialect = value.getParentBlock()->getParentOp()->getDialect();
79
80 Attribute cstAttr;
81 if (isa<IntegerType, IndexType>(value.getType())) {
82 cstAttr = IntegerAttr::get(value.getType(), *constant);
83 } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
84 cstAttr = SplatElementsAttr::get(shapedTy, *constant);
85 } else {
86 llvm::report_fatal_error(
87 Twine("FIXME: Don't know how to create a constant for this type: ") +
88 mlir::debugString(value.getType()));
89 }
90 solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
91}
92
96 auto inferrable = dyn_cast<InferIntRangeInterface>(op);
97 if (!inferrable) {
98 setAllToEntryStates(results);
99 return success();
100 }
101
102 LDBG() << "Inferring ranges for "
103 << OpWithFlags(op, OpPrintingFlags().skipRegions());
104 auto argRanges = llvm::map_to_vector(
105 operands, [](const IntegerValueRangeLattice *lattice) {
106 return lattice->getValue();
107 });
108
109 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
110 auto result = dyn_cast<OpResult>(v);
111 if (!result)
112 return;
113 assert(llvm::is_contained(op->getResults(), result));
114
115 LDBG() << "Inferred range " << attrs;
116 IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
117 IntegerValueRange oldRange = lattice->getValue();
118
119 ChangeResult changed = lattice->join(attrs);
120
121 // Catch loop results with loop variant bounds and conservatively make
122 // them [-inf, inf] so we don't circle around infinitely often (because
123 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
124 // and often can't).
125 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
126 return op->hasTrait<OpTrait::IsTerminator>();
127 });
128 if (isYieldedResult && !oldRange.isUninitialized() &&
129 !(lattice->getValue() == oldRange)) {
130 LDBG() << "Loop variant loop result detected";
132 }
133 propagateIfChanged(lattice, changed);
134 };
135
136 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
137 return success();
138}
139
141 Operation *op, const RegionSuccessor &successor,
142 ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
143 if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
144 LDBG() << "Inferring ranges for "
145 << OpWithFlags(op, OpPrintingFlags().skipRegions());
146
147 auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
148 return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
149 });
150
151 auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
152 auto arg = dyn_cast<BlockArgument>(v);
153 if (!arg)
154 return;
155 if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
156 return;
157
158 LDBG() << "Inferred range " << attrs;
159 IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
160 IntegerValueRange oldRange = lattice->getValue();
161
162 ChangeResult changed = lattice->join(attrs);
163
164 // Catch loop results with loop variant bounds and conservatively make
165 // them [-inf, inf] so we don't circle around infinitely often (because
166 // the dataflow analysis in MLIR doesn't attempt to work out trip counts
167 // and often can't).
168 bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
169 return op->hasTrait<OpTrait::IsTerminator>();
170 });
171 if (isYieldedValue && !oldRange.isUninitialized() &&
172 !(lattice->getValue() == oldRange)) {
173 LDBG() << "Loop variant loop result detected";
175 }
176 propagateIfChanged(lattice, changed);
177 };
178
179 inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
180 return;
181 }
182
183 /// Given a lower bound, upper bound, or step from a LoopLikeInterface return
184 /// the lower/upper bound for that result if possible.
185 auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType,
186 Block *block, bool getUpper) {
187 unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
188 if (auto attr = dyn_cast<Attribute>(loopBound)) {
189 if (auto bound = dyn_cast<IntegerAttr>(attr))
190 return bound.getValue();
191 } else if (auto value = llvm::dyn_cast<Value>(loopBound)) {
192 const IntegerValueRangeLattice *lattice =
194 if (lattice != nullptr && !lattice->getValue().isUninitialized())
195 return getUpper ? lattice->getValue().getValue().smax()
196 : lattice->getValue().getValue().smin();
197 }
198 // Given the results of getConstant{Lower,Upper}Bound()
199 // or getConstantStep() on a LoopLikeInterface return the lower/upper
200 // bound
201 return getUpper ? APInt::getSignedMaxValue(width)
202 : APInt::getSignedMinValue(width);
203 };
204
205 // Infer bounds for loop arguments that have static bounds
206 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
207 std::optional<llvm::SmallVector<Value>> maybeIvs =
208 loop.getLoopInductionVars();
209 if (!maybeIvs) {
210 return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
211 op, successor, argLattices, firstIndex);
212 }
213 // This shouldn't be returning nullopt if there are indunction variables.
214 SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds();
215 SmallVector<OpFoldResult> upperBounds = *loop.getLoopUpperBounds();
216 SmallVector<OpFoldResult> steps = *loop.getLoopSteps();
217 for (auto [iv, lowerBound, upperBound, step] :
218 llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) {
219 Block *block = iv.getParentBlock();
220 APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block,
221 /*getUpper=*/false);
222 APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block,
223 /*getUpper=*/true);
224 // Assume positivity for uniscoverable steps by way of getUpper = true.
225 APInt stepVal =
226 getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true);
227
228 if (stepVal.isNegative()) {
229 std::swap(min, max);
230 } else {
231 // Correct the upper bound by subtracting 1 so that it becomes a <=
232 // bound, because loops do not generally include their upper bound.
233 max -= 1;
234 }
235
236 // If we infer the lower bound to be larger than the upper bound, the
237 // resulting range is meaningless and should not be used in further
238 // inferences.
239 if (max.sge(min)) {
241 auto ivRange = ConstantIntRanges::fromSigned(min, max);
242 propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
243 }
244 }
245 return;
246 }
247
249 op, successor, argLattices, firstIndex);
250}
return success()
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
LatticeAnchor anchor
The lattice anchor to which the state belongs.
friend class DataFlowSolver
Allow the framework to access the dependents.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
The general data-flow analysis solver.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to an analysis state if it changed by pushing dependent work items to the back of...
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
StateT * getOrCreateState(AnchorT anchor)
Get the state associated with the given lattice anchor.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
static IntegerValueRange getMaxRange(Value value)
Create a maximal ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
This class represents a single result from folding an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
BlockArgListType getArguments()
Definition Region.h:81
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
void onUpdate(DataFlowSolver *solver) const override
When the lattice gets updated, propagate an update to users of the value using its use-def chain to s...
This lattice value represents a known constant value of a lattice.
static ConstantValue getUnknownConstant()
The state where the constant value is unknown.
LogicalResult visitOperation(Operation *op, ArrayRef< const IntegerValueRangeLattice * > operands, ArrayRef< IntegerValueRangeLattice * > results) override
Visit an operation.
void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< IntegerValueRangeLattice * > argLattices, unsigned firstIndex) override
Visit block arguments or operation results of an operation with region control-flow for which values ...
This lattice element represents the integer value range of an SSA value.
void onUpdate(DataFlowSolver *solver) const override
If the range can be narrowed to an integer constant, update the constant value of the SSA value.
This class represents a lattice holding a specific value of type ValueT.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
IntegerValueRangeLattice * getLatticeElement(Value value) override
void setAllToEntryStates(ArrayRef< IntegerValueRangeLattice * > lattices)
const IntegerValueRangeLattice * getLatticeElementFor(ProgramPoint *point, Value value)
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ArrayRef< StateT * > argLattices, unsigned firstIndex)
Given an operation with possible region control-flow, the lattices of the operands,...
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op)
Succeeds if an op can be converted to its unsigned equivalent without changing its semantics.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
ChangeResult
A result type used to indicate if a change happened.
static std::string debugString(T &&op)