MLIR  22.0.0git
StaticValueUtils.h
Go to the documentation of this file.
1 //===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for dealing with static values, e.g.,
10 // converting back and forth between Value and OpFoldResult. Such functionality
11 // is used in multiple dialects.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
16 #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
17 
18 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/Support/LLVM.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/SmallVectorExtras.h"
24 
25 namespace mlir {
26 
27 /// Return true if `v` is an IntegerAttr with value `0`.
/// An OpFoldResult may hold either an Attribute or a Value. NOTE(review):
/// confirm whether a Value produced by a constant op is also recognized here
/// (cf. `isConstantIntValue(v, 0)`), since only IntegerAttr is mentioned.
28 bool isZeroInteger(OpFoldResult v);
29 
30 /// Return true if `v` is an IntegerAttr with value `1`.
/// NOTE(review): as for isZeroInteger, confirm whether constant-producing
/// Values are recognized too (cf. `isConstantIntValue(v, 1)`).
31 bool isOneInteger(OpFoldResult v);
32 
33 /// Represents a range (offset, size, and stride) where each element of the
34 /// triple may be dynamic or static.
35 struct Range {
39 };
40 
41 /// Given an array of Range values, return a tuple of (offset vector, sizes
42 /// vector, and strides vector) formed by separating out the individual
43 /// elements of each range.
44 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
47 
48 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
49 /// a) it is an IntegerAttr
50 /// In other cases, the OpFoldResult is dispatched to `dynamicVec`.
51 /// In such dynamic cases, ShapedType::kDynamic is also pushed to
52 /// `staticVec`. This is useful to extract mixed static and dynamic entries
53 /// that come from an AttrSizedOperandSegments trait.
55  SmallVectorImpl<Value> &dynamicVec,
56  SmallVectorImpl<int64_t> &staticVec);
57 
58 /// Helper function to dispatch multiple OpFoldResults according to the
59 /// behavior of `dispatchIndexOpFoldResult` for a single
60 /// OpFoldResult.
62  SmallVectorImpl<Value> &dynamicVec,
63  SmallVectorImpl<int64_t> &staticVec);
64 
65 /// Given OpFoldResult representing dim size value (*), generates a pair of
66 /// sizes:
67 /// * 1st result, static value, contains an int64_t dim size that can be used
68 /// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
69 /// * 2nd result, dynamic value, contains OpFoldResult encapsulating the
70 /// actual dim size (either original or updated input value).
71 /// For input sizes for which it is possible to extract a constant Attribute,
72 /// replaces the original size value with an integer attribute (unless it's
73 /// already a constant Attribute). The 1st return value also becomes the actual
74 /// integer size (as opposed to ShapedType::kDynamic).
75 ///
76 /// (*) This hook is usually used when, given input sizes as OpFoldResult,
77 /// it's required to generate two vectors:
78 /// * sizes as int64_t to generate a shape,
79 /// * sizes as OpFoldResult for sizes-like attribute.
80 /// Please update this comment if you identify other use cases.
81 std::pair<int64_t, OpFoldResult>
83 
84 /// Extract integer values from the assumed ArrayAttr of IntegerAttr.
85 template <typename IntTy>
87  return llvm::to_vector(
88  llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
89  return cast<IntegerAttr>(a).getInt();
90  }));
91 }
92 
93 /// Given a value, try to extract a constant Attribute. If this fails, return
94 /// the original value.
95 OpFoldResult getAsOpFoldResult(Value val);
96 /// Given an array of values, try to extract a constant Attribute from each
97 /// value. Elements for which this fails are returned as the original value.
98 SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
99 /// Convert `arrayAttr` to a vector of OpFoldResult, one per array element.
100 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
101 
102 /// Convert `val` (or each element of `values`) to an integer attribute of
103 /// index type created in `ctx`, and return the result(s) as OpFoldResult.
104 OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
105 SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
106  ArrayRef<int64_t> values);
107 
108 /// If `ofr` is a constant integer or an IntegerAttr, return the integer;
109 /// otherwise return std::nullopt. The second element of the pair is true when
110 /// the value has index type, whose bitwidth is undefined (the APInt is then
/// 64 bits wide).
111 std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
112 /// If `ofr` is a constant integer or an IntegerAttr, return the integer;
/// otherwise return std::nullopt.
113 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
114 /// If all `ofrs` are constant integers or IntegerAttrs, return the integers;
/// if any element is not constant, return std::nullopt.
115 std::optional<SmallVector<int64_t>>
116 getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
117 
118 /// Return true if `ofr` is a constant integer equal to `value`.
119 bool isConstantIntValue(OpFoldResult ofr, int64_t value);
120 /// Return true if all of `ofrs` are constant integers equal to `value`.
121 bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
122 /// Return true if all of `ofrs` are constant integers equal to the
123 /// corresponding value in `values`.
124 bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
125  ArrayRef<int64_t> values);
126 
127 /// Return true if `ofr1` and `ofr2` are the same integer constant attribute
128 /// value or the same SSA value. Integer bitwidth and type mismatches are
129 /// ignored; they arise because there is no IndexAttr and IndexType has no
130 /// bitwidth.
131 bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
/// Array form of isEqualConstantIntOrValue. NOTE(review): presumably true only
/// when both arrays have the same length and match element-wise; confirm.
132 bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
133  ArrayRef<OpFoldResult> ofrs2);
134 
135 // To convert an OpFoldResult to a Value of index type, see:
136 // mlir/include/mlir/Dialect/Arith/Utils/Utils.h
137 // TODO: find a better common landing place.
138 //
139 // Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
140 // OpFoldResult ofr);
141 
142 // To convert a list of OpFoldResults to Values of index type, see:
143 // mlir/include/mlir/Dialect/Arith/Utils/Utils.h
144 // TODO: find a better common landing place.
145 //
146 // SmallVector<Value>
147 // getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
148 // ArrayRef<OpFoldResult> valueOrAttrVec);
149 
150 /// Return a vector of OpFoldResults with the same size as `staticValues`, in
151 /// which every element for which ShapedType::isDynamic is true is replaced by
152 /// an entry of `dynamicValues` (presumably consumed in order).
153 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
154  ValueRange dynamicValues,
155  MLIRContext *context);
/// Same as the overload taking an MLIRContext, but takes a Builder instead.
156 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
157  ValueRange dynamicValues, Builder &b);
158 
159 /// Decompose a vector of mixed static or dynamic values into the
160 /// corresponding pair of arrays (static values, dynamic values); dynamic
161 /// positions presumably hold ShapedType::kDynamic in the static array. This is
/// the inverse function of `getMixedValues`.
162 std::pair<SmallVector<int64_t>, SmallVector<Value>>
163 decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);
164 
165 /// Helper to sort `values` according to matching `keys`, using `compare` to
/// order two keys, and return the reordered values. Overloads exist for
/// Value, OpFoldResult and int64_t elements. NOTE(review): presumably
/// `keys[i]` is the key of `values[i]`, so both must have the same length.
166 SmallVector<Value>
167 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
168  llvm::function_ref<bool(Attribute, Attribute)> compare);
169 SmallVector<OpFoldResult>
170 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
171  llvm::function_ref<bool(Attribute, Attribute)> compare);
172 SmallVector<int64_t>
173 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
174  llvm::function_ref<bool(Attribute, Attribute)> compare);
175 
176 /// Helper function to check whether the passed-in sizes or offsets
177 /// (`sizesOrOffsets`) are valid. This can be used to re-check whether
178 /// dimensions are still valid after constant folding the dynamic dimensions.
/// NOTE(review): the exact validity criterion lives in the implementation;
/// presumably static entries must be non-negative — confirm.
179 bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
180 
181 /// Helper function to check whether the passed-in `strides` are valid. This
182 /// can be used to re-check whether dimensions are still valid after constant
183 /// folding the dynamic dimensions.
/// NOTE(review): presumably static strides must be non-zero — confirm.
184 bool hasValidStrides(SmallVector<int64_t> strides);
185 
186 /// Fold constant entries of `ofrs` in place: each element that is a constant
187 /// value is replaced by an attribute. Returns "success" when any element was
188 /// folded and "failure" when no folding happened. If `onlyNonNegative` is
189 /// set, only non-negative constants are folded; if `onlyNonZero` is set, only
/// non-zero constants are folded.
190 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
191  bool onlyNonNegative = false,
192  bool onlyNonZero = false);
193 
194 /// Fold constant entries of `offsetsOrSizes` in place, replacing each
195 /// constant value by an attribute. Returns "success" when any element was
196 /// folded and "failure" when no folding happened. Invalid values are not folded
197 /// to avoid canonicalization crashes (presumably judged as in
/// hasValidSizesOffsets).
198 LogicalResult
199 foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
200 
201 /// Fold constant entries of `strides` in place, replacing each constant
202 /// value by an attribute. Returns "success" when any element was folded and
203 /// "failure" when no folding happened. Invalid values are not folded to avoid
204 /// canonicalization crashes (presumably judged as in hasValidStrides).
205 LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
206 
207 /// Return the number of iterations for a loop with a lower bound `lb`, upper
208 /// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
209 /// comparison between lb and ub is signed or unsigned. A negative step or a
210 /// lower bound greater than the upper bound are considered invalid and will
211 /// yield a zero trip count.
212 /// The `computeUbMinusLb` callback is invoked to compute the difference between
213 /// the upper and lower bound when not constant. It can be used by the client
214 /// to compute a static difference when the bounds are not constant.
215 ///
216 /// For example, in the following code:
217 ///
218 /// %ub = arith.addi nsw %lb, %c16_i32 : i32
219 /// %1 = scf.for %arg0 = %lb to %ub ...
220 ///
221 /// %ub is computed as a static offset from %lb, so the difference ub - lb is
/// constant even though neither bound is.
222 /// Note: the matched addition should be nsw/nuw (matching the loop comparison)
223 /// to avoid overflow, otherwise an overflow would imply a zero trip count.
224 std::optional<APInt> constantTripCount(
225  OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
226  llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
228 
229 /// Idiomatic saturated operations on values like offsets, sizes, and strides.
231  static SaturatedInteger wrap(int64_t v) {
232  return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
233  : SaturatedInteger{false, v};
234  }
235  int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
236  FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
237  if (saturated && !other.saturated)
238  return other;
239  if (!saturated && !other.saturated && v != other.v)
240  return failure();
241  return *this;
242  }
244  return (saturated && other.saturated) ||
245  (!saturated && !other.saturated && v == other.v);
246  }
247  bool operator!=(SaturatedInteger other) { return !(*this == other); }
249  if (saturated || other.saturated)
250  return SaturatedInteger{true, 0};
251  return SaturatedInteger{false, other.v + v};
252  }
254  // Multiplication with 0 is always 0.
255  if (!other.saturated && other.v == 0)
256  return SaturatedInteger{false, 0};
257  if (!saturated && v == 0)
258  return SaturatedInteger{false, 0};
259  // Otherwise, if this or the other integer is dynamic, so is the result.
260  if (saturated || other.saturated)
261  return SaturatedInteger{true, 0};
262  return SaturatedInteger{false, other.v * v};
263  }
264  bool saturated = true;
265  int64_t v = 0;
266 };
267 
268 } // namespace mlir
269 
270 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
static std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition: SCF.cpp:115
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Idiomatic saturated operations on values like offsets, sizes, and strides.
SaturatedInteger operator+(SaturatedInteger other)
static SaturatedInteger wrap(int64_t v)
SaturatedInteger operator*(SaturatedInteger other)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool operator!=(SaturatedInteger other)
bool operator==(SaturatedInteger other)