MLIR 23.0.0git
StaticValueUtils.h
Go to the documentation of this file.
1//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines utilities for dealing with static values, e.g.,
10// converting back and forth between Value and OpFoldResult. Such functionality
11// is used in multiple dialects.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
16#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
17
18#include "mlir/IR/Builders.h"
21#include "mlir/Support/LLVM.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24
25namespace mlir {
26
27/// Return "true" if `v` is an integer value/attribute with constant value `0`.
29
30/// Return "true" if `v` is a float value/attribute with constant value `0.0`.
32
33/// Return "true" if `v` is an integer/float value/attribute with constant
34/// value zero.
36
37/// Return true if `v` is an IntegerAttr with value `1`.
39
40/// Represents a range (offset, size, and stride) where each element of the
41/// triple may be dynamic or static.
47
48/// Given an array of Range values, return a tuple of (offset vector, sizes
49/// vector, and strides vector) formed by separating out the individual
50/// elements of each range.
51std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
54
55/// Helper function to dispatch an OpFoldResult into `staticVec` if:
56/// a) it is an IntegerAttr
57/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
58/// In such dynamic cases, ShapedType::kDynamic is also pushed to
59/// `staticVec`. This is useful to extract mixed static and dynamic entries
60/// that come from an AttrSizedOperandSegments trait.
62 SmallVectorImpl<Value> &dynamicVec,
63 SmallVectorImpl<int64_t> &staticVec);
64
65/// Helper function to dispatch multiple OpFoldResults according to the
66/// behavior of `dispatchIndexOpFoldResult(OpFoldResult ofr` for a single
67/// OpFoldResult.
69 SmallVectorImpl<Value> &dynamicVec,
70 SmallVectorImpl<int64_t> &staticVec);
71
72/// Given OpFoldResult representing dim size value (*), generates a pair of
73/// sizes:
74/// * 1st result, static value, contains an int64_t dim size that can be used
75/// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
76/// * 2nd result, dynamic value, contains OpFoldResult encapsulating the
77/// actual dim size (either original or updated input value).
78/// For input sizes for which it is possible to extract a constant Attribute,
79/// replaces the original size value with an integer attribute (unless it's
80/// already a constant Attribute). The 1st return value also becomes the actual
81/// integer size (as opposed ShapedType::kDynamic).
82///
83/// (*) This hook is usually used when, given input sizes as OpFoldResult,
84/// it's required to generate two vectors:
85/// * sizes as int64_t to generate a shape,
86/// * sizes as OpFoldResult for sizes-like attribute.
87/// Please update this comment if you identify other use cases.
88std::pair<int64_t, OpFoldResult>
90
91/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
92template <typename IntTy>
94 return llvm::map_to_vector(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
95 return cast<IntegerAttr>(a).getInt();
96 });
97}
98
99/// Given a value, try to extract a constant Attribute. If this fails, return
100/// the original value.
101OpFoldResult getAsOpFoldResult(Value val);
102/// Given an array of values, try to extract a constant Attribute from each
103/// value. If this fails, return the original value.
105/// Convert `arrayAttr` to a vector of OpFoldResult.
107
108/// Convert int64_t to integer attributes of index type and return them as
109/// OpFoldResult.
110OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
112 ArrayRef<int64_t> values);
113
114/// If ofr is a constant integer or an IntegerAttr, return the integer.
115/// The second return value indicates whether the value is an index type
116/// and thus the bitwidth is not defined (the APInt will be set with 64bits).
117std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
118/// If ofr is a constant integer or an IntegerAttr, return the integer.
119std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
120/// If all ofrs are constant integers or IntegerAttrs, return the integers.
121std::optional<SmallVector<int64_t>>
122getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
123
124/// Return true if `ofr` is constant integer equal to `value`.
125bool isConstantIntValue(OpFoldResult ofr, int64_t value);
126/// Return true if all of `ofrs` are constant integers equal to `value`.
127bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
128/// Return true if all of `ofrs` are constant integers equal to the
129/// corresponding value in `values`.
130bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
131 ArrayRef<int64_t> values);
132
133/// Return true if ofr1 and ofr2 are the same integer constant attribute
134/// values or the same SSA value. Ignore integer bitwitdh and type mismatch
135/// that come from the fact there is no IndexAttr and that IndexType have no
136/// bitwidth.
137bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
138bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
139 ArrayRef<OpFoldResult> ofrs2);
140
141// To convert an OpFoldResult to a Value of index type, see:
142// mlir/include/mlir/Dialect/Arith/Utils/Utils.h
143// TODO: find a better common landing place.
144//
145// Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
146// OpFoldResult ofr);
147
// To convert a vector of OpFoldResults to a vector of Values of index type, see:
149// mlir/include/mlir/Dialect/Arith/Utils/Utils.h
150// TODO: find a better common landing place.
151//
152// SmallVector<Value>
153// getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
154// ArrayRef<OpFoldResult> valueOrAttrVec);
155
156/// Return a vector of OpFoldResults with the same size a staticValues, but
157/// all elements for which ShapedType::isDynamic is true, will be replaced by
158/// dynamicValues.
159SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
160 ValueRange dynamicValues,
161 MLIRContext *context);
162SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
163 ValueRange dynamicValues, Builder &b);
164
165/// Decompose a vector of mixed static or dynamic values into the
166/// corresponding pair of arrays. This is the inverse function of
167/// `getMixedValues`.
168std::pair<SmallVector<int64_t>, SmallVector<Value>>
169decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);
170
171/// Helper to sort `values` according to matching `keys`.
173getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
174 llvm::function_ref<bool(Attribute, Attribute)> compare);
176getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
177 llvm::function_ref<bool(Attribute, Attribute)> compare);
179getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
180 llvm::function_ref<bool(Attribute, Attribute)> compare);
181
182/// Helper function to check whether the passed in `sizes` or `offsets` are
183/// valid. This can be used to re-check whether dimensions are still valid
184/// after constant folding the dynamic dimensions.
185bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
186
187/// Helper function to check whether the passed in `strides` are valid. This
188/// can be used to re-check whether dimensions are still valid after constant
189/// folding the dynamic dimensions.
191
192/// Returns "success" when any of the elements in `ofrs` is a constant value. In
193/// that case the value is replaced by an attribute. Returns "failure" when no
194/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only
195/// non-negative and non-zero constant values are folded respectively.
196LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
197 bool onlyNonNegative = false,
198 bool onlyNonZero = false);
199
200/// Returns "success" when any of the elements in `offsetsOrSizes` is a
201/// constant value. In that case the value is replaced by an attribute. Returns
202/// "failure" when no folding happened. Invalid values are not folded to avoid
203/// canonicalization crashes.
204LogicalResult
205foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
206
207/// Returns "success" when any of the elements in `strides` is a constant
208/// value. In that case the value is replaced by an attribute. Returns
209/// "failure" when no folding happened. Invalid values are not folded to avoid
210/// canonicalization crashes.
211LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
212
213/// Return the number of iterations for a loop with a lower bound `lb`, upper
214/// bound `ub` and step `step`, as an unsigned integer. The `isSigned` flag
215/// indicates whether the loop comparison between lb and ub is signed or
216/// unsigned. (The result of this function must be interpreted as an unsigned
217/// integer.) A lower bound greater than the upper bound is considered invalid
218/// and will yield a zero trip count.
219/// The `computeUbMinusLb` callback is invoked to compute the difference between
220/// the upper and lower bound when not constant. It can be used by the client
221/// to compute a static difference when the bounds are not constant.
222///
223/// For example, the following code:
224///
225/// %ub = arith.addi nsw %lb, %c16_i32 : i32
226/// %1 = scf.for %arg0 = %lb to %ub ...
227///
228/// where %ub is computed as a static offset from %lb.
229/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
230/// to avoid overflow, otherwise an overflow would imply a zero trip count.
231std::optional<APInt> constantTripCount(
232 OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
233 llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
234 computeUbMinusLb);
235
236/// Idiomatic saturated operations on values like offsets, sizes, and strides.
239 return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
240 : SaturatedInteger{false, v};
241 }
242 int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
243 FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
244 if (saturated && !other.saturated)
245 return other;
246 if (!saturated && !other.saturated && v != other.v)
247 return failure();
248 return *this;
249 }
251 return (saturated && other.saturated) ||
252 (!saturated && !other.saturated && v == other.v);
253 }
254 bool operator!=(SaturatedInteger other) { return !(*this == other); }
256 if (saturated || other.saturated)
257 return SaturatedInteger{true, 0};
258 return SaturatedInteger{false, other.v + v};
259 }
261 // Multiplication with 0 is always 0.
262 if (!other.saturated && other.v == 0)
263 return SaturatedInteger{false, 0};
264 if (!saturated && v == 0)
265 return SaturatedInteger{false, 0};
266 // Otherwise, if this or the other integer is dynamic, so is the result.
267 if (saturated || other.saturated)
268 return SaturatedInteger{true, 0};
269 return SaturatedInteger{false, other.v * v};
270 }
271 bool saturated = true;
273};
274
275} // namespace mlir
276
277#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
This class represents a single result from folding an operation.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool isZeroIntegerOrFloat(OpFoldResult v)
Return "true" if v is an integer/float value/attribute with constant value zero.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if it is an IntegerAttr. In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
bool isZeroFloat(OpFoldResult v)
Return "true" if v is a float value/attribute with constant value 0.0.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Idiomatic saturated operations on values like offsets, sizes, and strides.
SaturatedInteger operator+(SaturatedInteger other)
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
SaturatedInteger operator*(SaturatedInteger other)
bool operator!=(SaturatedInteger other)
bool operator==(SaturatedInteger other)