IndexingUtils.cpp
//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;

template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = static_cast<int64_t>(strides.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}
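
// Example: for row-major sizes [2, 3, 4], the suffix product yields the
// strides [12, 4, 1]:
//   strides[2] = 1, strides[1] = 1 * 4 = 4, strides[0] = 4 * 3 = 12.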

template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty; let zip_equal assert if only one is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}
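
// Example: computeElementwiseMulImpl([1, 2, 3], [4, 5, 6]) == [4, 10, 18].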

template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}
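
// Example: offsets [1, 2, 3] with basis [12, 4, 1] linearize to
//   1 * 12 + 2 * 4 + 3 * 1 = 23.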

template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}
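
// Example: delinearizing 23 with strides [12, 4, 1] recovers [1, 2, 3]:
//   23 / 12 = 1 (rem 11), 11 / 4 = 2 (rem 3), 3 / 1 = 3.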

//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert((sizes.empty() ||
          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  return llvm::product_of(basis);
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If the division is not exact, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }
  // At this point we have computed the ratio (in reverse) for the common
  // trailing dimensions. Fill with the remaining entries from the shape
  // (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}
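
// Example: computeShapeRatio([6, 8, 12], [4, 3]) aligns the subShape with the
// trailing dimensions and returns [6, 2, 4], since [6, 8, 12] is covered by
// [6, 2, 4] tiles of shape [1, 4, 3]. With subShape [5, 3] it would return
// std::nullopt because 8 % 5 != 0.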

//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  return llvm::sum_of(basis, getAffineConstantExpr(0, ctx));
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  return llvm::product_of(basis, getAffineConstantExpr(1, ctx));
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}
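
// Example: invertPermutationVector([2, 0, 1]) == [1, 2, 0]; position 0 of the
// result holds the index that the permutation sends to 0 (namely 1), and so on.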

bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
  for (auto i : llvm::seq<int64_t>(0, permutation.size()))
    if (permutation[i] != i)
      return false;
  return true;
}

bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())
      return false;
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}
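
// Example: computePermutationVector(5, /*positions=*/{1, 3},
// /*desiredPositions=*/{0, 2}) returns [1, 0, 3, 2, 4]: 1 moves to slot 0,
// 3 moves to slot 2, and the remaining slots are filled in ascending order
// with the unused values 0, 2 and 4.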

SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
                                    ArrayRef<int64_t> dropPositions) {
  assert(inputPerm.size() >= dropPositions.size() &&
         "expected inputPerm size to be at least the number of positions to "
         "drop");
  SmallVector<int64_t> res;
  unsigned permSize = inputPerm.size();
  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
    int64_t targetIndex = inputPerm[inputIndex];
    bool shouldDrop = false;
    unsigned dropSize = dropPositions.size();
    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
        shouldDrop = true;
        break;
      }
      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
        targetIndex--;
      }
    }
    if (!shouldDrop) {
      res.push_back(targetIndex);
    }
  }
  return res;
}
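
// Example: dropDims([2, 0, 3, 1], /*dropPositions=*/{1}) returns [1, 0, 2]:
// entry 1 is dropped, and the surviving entries 2 and 3 are renumbered down
// by one to account for the removed value.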

SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}
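
// Example: for an ArrayAttr holding [0, 1, 2, 3, 4], getI64SubArray with
// dropFront = 1 and dropBack = 1 returns [1, 2, 3].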

// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}
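
// Example: for rank 2, the returned expression is
//   expr = s0 + s1 * s2 + s3 * s4
// with values = [sourceOffset, indices[0], strides[0], indices[1],
// strides[1]], i.e. offset + sum_i(indices[i] * strides[i]).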

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices) {
  return computeLinearIndex(
      sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
      getAsOpFoldResult(ValueRange{indices}));
}

//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}
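
// Example: padTileShapeToSize([2, 3], 4) returns [1, 1, 2, 3].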

mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}
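
// Example: for shape [8, 8], tileShape [4, 2] and the identity loopOrder
// [0, 1], the shape ratio is [2, 4], maxLinearIndex is 8, and sliceStrides
// is [4, 1]; the range thus enumerates 8 tiles in row-major order.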

SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}
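
// Example: continuing the [8, 8] / [4, 2] case above, linear index 5
// delinearizes to tile coordinates [1, 1], which scale by the tile shape to
// the element offsets [4, 2].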

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}