MLIR 23.0.0git
XeGPUContiguityAnalysis.cpp
Go to the documentation of this file.
1//===- XeGPUContiguityAnalysis.cpp - Offset contiguity analysis ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the contiguity analysis. It computes, for a memory
10// operation (i.e., `xegpu.load` / `xegpu.store`), how many elements are
11// contiguous along the innermost offsets dimension, and stamps that count as an
12// attribute on the op. The analysis is a small XeGPU-local
13// axis-info dataflow tracking per-axis `contiguity`, `constancy`, and
14// `divisibility`; the stamped value is the inner-dim `contiguity`.
15//
16// Contiguity is a target-independent property of the offsets.
17//
18// The analysis tracks per-axis information for vectors of integer / index
19// type at any rank, against the innermost dimension.
20//
21// The analysis takes its inspiration from the Triton axis-info analysis.
22//===----------------------------------------------------------------------===//
23
36#include "llvm/ADT/APInt.h"
37#include "llvm/Support/MathExtras.h"
38#include <numeric>
39#include <optional>
40
41#define DEBUG_TYPE "xegpu-contiguity-analysis"
42
43using namespace mlir;
44
45// AxisInfo and AxisInfoAnalysis are intentionally placed in a named namespace
46// (not anonymous) so the `dataflow::Lattice<AxisInfo>` template instantiation
47// gets a stable, externally-visible name. The TypeID machinery requires that
48// for `MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID` to apply.
50
51//===----------------------------------------------------------------------===//
52// AxisInfo: per-axis contiguity / constancy / divisibility lattice.
53//===----------------------------------------------------------------------===//
54
/// "Unconstrained" marker for a dimension: far larger than any real shape
/// extent, so a component-wise `min` against a real bound always yields the
/// real bound. Deliberately not `numeric_limits<int64_t>::max()`: divisibility
/// values get multiplied (e.g. in `visitMul`), and the product of two values of
/// this size (2^60) is still representable in `int64_t`.
static constexpr int64_t kAxisInfoTop = int64_t{1} << 30;
60
61/// Per-dimension axis information for an SSA vector value of integer / index
62/// type. The fields describe, for each dimension `d`, the pattern of the values
63/// along `d` (examples use a 1-D vector, so `d` is the only / innermost dim):
64/// - `contiguity[d]`: longest run that increases by exactly 1.
65/// `[0, 1, 2, 3]` -> 4; `[0, 1, 0, 1]` -> 2.
66/// - `constancy[d]`: longest run of equal values.
67/// `[5, 5, 5, 5]` -> 4; `[5, 5, 6, 6]` -> 2.
68/// - `divisibility[d]`: a power-of-two divisor of the first (base) element of
69/// each run along `d`. `[8, 16, 24]` -> 8; `[3, 6, 9]` -> 1; `[0, 1, 2]` -> top.
70/// - `knownConstant`: the value, if the whole vector is one constant.
71/// `dense<7>` -> 7; `[0, 1, 2]` -> nullopt.
72/// - `innerStride`: if set, consecutive values along the innermost dim differ
73/// by this constant step. `1` is the contiguous case (`[0,1,2,3]`), `0` the
74/// all-equal case (`[5,5,5,5]`); any other value is a strided progression
75/// (`[0,4,8,12]` -> 4) that `contiguity`/`constancy` can't represent (both
76/// read 1). For a multi-dim vector each inner-dim slice is its own
77/// progression; their bases may differ, with the shared inner alignment in
78/// `divisibility[innerDim]`.
79///
80/// Only `contiguity[innerDim]` is consumed when stamping, but all dimensions
81/// are tracked because `vector.transpose` / `vector.shape_cast` permute or move
82/// per-dim info between axes, so an intermediate value's outer dims can become
83/// the inner dim of a later value.
84///
85/// Pessimistic / entry value: contiguity=1, constancy=1, divisibility=1,
86/// innerStride absent.
87struct AxisInfo {
91 std::optional<int64_t> knownConstant;
92 std::optional<int64_t> innerStride;
93
94 AxisInfo() = default;
95
96 static AxisInfo getPessimistic(unsigned rank) {
97 AxisInfo v;
98 v.contiguity.assign(rank, 1);
99 v.constancy.assign(rank, 1);
100 v.divisibility.assign(rank, 1);
101 return v;
102 }
103
104 unsigned getRank() const { return contiguity.size(); }
105 bool isInitialized() const { return getRank() > 0; }
106
107 bool operator==(const AxisInfo &rhs) const {
108 return contiguity == rhs.contiguity && constancy == rhs.constancy &&
109 divisibility == rhs.divisibility &&
110 knownConstant == rhs.knownConstant && innerStride == rhs.innerStride;
111 }
112
113 /// Conservative join (lattice meet). When a value can arrive from several
114 /// paths (e.g. a block argument, or `arith.select`), only what holds on
115 /// *every* path is safe to assume. So we keep the weaker fact per field:
116 /// `min` of each run length (a run is only guaranteed as long as the shortest
117 /// incoming one), `gcd` of divisibility, and a value/stride only when both
118 /// sides agree. This is what makes the contiguity we later stamp sound rather
119 /// than "undecidable" — it is the largest run guaranteed on all paths.
120 static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs) {
121 if (!lhs.isInitialized())
122 return rhs;
123 if (!rhs.isInitialized())
124 return lhs;
125 assert(lhs.getRank() == rhs.getRank());
126 AxisInfo out;
127 unsigned r = lhs.getRank();
128 out.contiguity.resize(r);
129 out.constancy.resize(r);
130 out.divisibility.resize(r);
131 for (unsigned d = 0; d < r; ++d) {
132 out.contiguity[d] = std::min(lhs.contiguity[d], rhs.contiguity[d]);
133 out.constancy[d] = std::min(lhs.constancy[d], rhs.constancy[d]);
134 out.divisibility[d] = std::gcd(lhs.divisibility[d], rhs.divisibility[d]);
135 }
136 if (lhs.knownConstant && rhs.knownConstant &&
137 *lhs.knownConstant == *rhs.knownConstant)
138 out.knownConstant = lhs.knownConstant;
139 if (lhs.innerStride && rhs.innerStride &&
140 *lhs.innerStride == *rhs.innerStride)
141 out.innerStride = lhs.innerStride;
142 return out;
143 }
144
145 void print(raw_ostream &os) const {
146 os << "contiguity=[";
147 llvm::interleaveComma(contiguity, os);
148 os << "] constancy=[";
149 llvm::interleaveComma(constancy, os);
150 os << "] divisibility=[";
151 llvm::interleaveComma(divisibility, os);
152 os << "]";
153 if (knownConstant)
154 os << " const=" << *knownConstant;
155 if (innerStride)
156 os << " innerStride=" << *innerStride;
157 }
158};
159
161
162/// Power-of-two divisor of `v`. Returns `kAxisInfoTop` when `v == 0`.
164 if (v == 0)
165 return kAxisInfoTop;
166 uint64_t u = static_cast<uint64_t>(std::abs(v));
167 return static_cast<int64_t>(u & (~u + 1));
168}
169
170/// Initial lattice value for an SSA value when no transfer function applies.
172 if (auto vt = dyn_cast<VectorType>(v.getType()))
173 return AxisInfo::getPessimistic(vt.getRank());
174 return AxisInfo::getPessimistic(1);
175}
176
177/// AxisInfo for a tensor that is constant `c` everywhere.
179 AxisInfo v;
180 unsigned r = shape.size();
181 v.contiguity.assign(r, 1);
182 v.constancy.assign(shape.begin(), shape.end());
183 v.divisibility.assign(r, highestPow2Divisor(c));
184 v.knownConstant = c;
185 v.innerStride = 0;
186 return v;
187}
188
189/// Sparse forward dataflow analysis that computes `AxisInfo` for vector
190/// values reachable from the entry of the analyzed op.
192 : public dataflow::SparseForwardDataFlowAnalysis<AxisInfoLattice> {
193public:
195 using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
196
197 LogicalResult visitOperation(Operation *op,
199 ArrayRef<AxisInfoLattice *> results) override {
200 if (auto step = dyn_cast<vector::StepOp>(op))
201 return visitStep(step, results);
202 if (auto cst = dyn_cast<arith::ConstantOp>(op))
203 return visitConstant(cst, results);
204 if (auto bcast = dyn_cast<vector::BroadcastOp>(op))
205 return visitBroadcast(bcast, operands, results);
206 if (auto sc = dyn_cast<vector::ShapeCastOp>(op))
207 return visitShapeCast(sc, operands, results);
208 if (auto tp = dyn_cast<vector::TransposeOp>(op))
209 return visitTranspose(tp, operands, results);
210 if (auto add = dyn_cast<arith::AddIOp>(op))
211 return visitAddSub</*IsSub=*/false>(add, operands, results);
212 if (auto sub = dyn_cast<arith::SubIOp>(op))
213 return visitAddSub</*IsSub=*/true>(sub, operands, results);
214 if (auto mul = dyn_cast<arith::MulIOp>(op))
215 return visitMul(mul, operands, results);
216 if (auto div = dyn_cast<arith::DivUIOp>(op))
217 return visitDivRem</*IsSigned=*/false, /*IsRem=*/false>(div, operands,
218 results);
219 if (auto div = dyn_cast<arith::DivSIOp>(op))
220 return visitDivRem</*IsSigned=*/true, /*IsRem=*/false>(div, operands,
221 results);
222 if (auto rem = dyn_cast<arith::RemUIOp>(op))
223 return visitDivRem</*IsSigned=*/false, /*IsRem=*/true>(rem, operands,
224 results);
225 if (auto rem = dyn_cast<arith::RemSIOp>(op))
226 return visitDivRem</*IsSigned=*/true, /*IsRem=*/true>(rem, operands,
227 results);
228 if (auto andi = dyn_cast<arith::AndIOp>(op))
229 return visitAndI(andi, operands, results);
230 if (auto shl = dyn_cast<arith::ShLIOp>(op))
231 return visitShift</*IsLeft=*/true>(shl, operands, results);
232 if (auto shr = dyn_cast<arith::ShRUIOp>(op))
233 return visitShift</*IsLeft=*/false>(shr, operands, results);
234 if (auto sel = dyn_cast<arith::SelectOp>(op))
235 return visitSelect(sel, operands, results);
236 if (auto cast = dyn_cast<arith::IndexCastOp>(op))
237 return visitPassThrough(cast, operands, results);
238 if (auto cast = dyn_cast<arith::IndexCastUIOp>(op))
239 return visitPassThrough(cast, operands, results);
240 setAllPessimistic(op, results);
241 return success();
242 }
243
244 void setToEntryState(AxisInfoLattice *lattice) override {
245 propagateIfChanged(lattice,
246 lattice->join(entryStateFor(lattice->getAnchor())));
247 }
248
249private:
250 void setAllPessimistic(Operation *op, ArrayRef<AxisInfoLattice *> results) {
251 for (auto [r, lat] : llvm::zip(op->getResults(), results)) {
252 AxisInfo state = entryStateFor(r);
253 propagateIfChanged(lat, lat->join(state));
254 }
255 }
256
257 // vector.step is always 1-D and produces [0, 1, ..., n-1].
258 LogicalResult visitStep(vector::StepOp op,
260 auto vt = cast<VectorType>(op.getType());
261 int64_t n = vt.getNumElements();
262 AxisInfo v;
263 v.contiguity = {n};
264 v.constancy = {1};
265 v.divisibility = {kAxisInfoTop};
266 v.innerStride = 1;
267 propagateIfChanged(results[0], results[0]->join(v));
268 return success();
269 }
270
271 // arith.constant. The four cases below, by example:
272 // - scalar int `arith.constant 8 : index`
273 // - non-int scalar `arith.constant 1.0 : f32` (pessimistic)
274 // - splat vector `arith.constant dense<5> : vector<16xindex>`
275 // - dense vector `arith.constant dense<[0,1,2,3]> : vector<4xindex>`
276 LogicalResult visitConstant(arith::ConstantOp op,
277 ArrayRef<AxisInfoLattice *> results) {
278 auto vt = dyn_cast<VectorType>(op.getType());
279 if (!vt) {
280 // Scalar integer, e.g. `arith.constant 8 : index`: a single known value,
281 // contiguity/constancy 1, divisibility from the value (8 -> 8).
282 if (auto intAttr = dyn_cast<IntegerAttr>(op.getValue())) {
283 int64_t c = intAttr.getValue().getSExtValue();
284 AxisInfo v;
285 v.contiguity = {1};
286 v.constancy = {1};
287 v.divisibility = {highestPow2Divisor(c)};
288 v.knownConstant = c;
289 propagateIfChanged(results[0], results[0]->join(v));
290 return success();
291 }
292 // Non-integer scalar, e.g. `arith.constant 1.0 : f32`: nothing to track.
293 setAllPessimistic(op, results);
294 return success();
295 }
296 auto dense = dyn_cast<DenseIntElementsAttr>(op.getValue());
297 if (!dense) {
298 setAllPessimistic(op, results);
299 return success();
300 }
301 auto shape = vt.getShape();
302 // Splat, e.g. `arith.constant dense<5> : vector<16xindex>`: every element
303 // equal, so constancy = full extent, innerStride 0.
304 if (dense.isSplat()) {
305 int64_t c = dense.getSplatValue<APInt>().getSExtValue();
306 AxisInfo v = splatAxisInfo(shape, c);
307 propagateIfChanged(results[0], results[0]->join(v));
308 return success();
309 }
310
311 // General dense vector, e.g. `arith.constant dense<[0,1,2,3]> :
312 // vector<4xindex>`. Compute innermost-dim contiguity / constancy /
313 // base-divisibility by iterating the dense values along the inner stride
314 // (here stride 1 -> contiguity 4). Outer dims report pessimistic (1)
315 // unless they collapse trivially below.
316 unsigned r = shape.size();
317 int64_t inner = shape.back();
318 int64_t outer = vt.getNumElements() / inner;
319 if (inner < 2 || outer < 1) {
320 // Can't meaningfully analyze a 0/1-element inner dim; fall back to
321 // splat handling already covered, otherwise pessimistic.
322 AxisInfo v = AxisInfo::getPessimistic(r);
323 // For a 1-element inner dim the inner-dim contiguity/constancy is
324 // trivially 1 (already pessimistic).
325 propagateIfChanged(results[0], results[0]->join(v));
326 return success();
327 }
328 auto values = llvm::to_vector(dense.getValues<APInt>());
329 int64_t innerCont = inner;
330 int64_t innerConst = inner;
331 int64_t innerStride = values[1].getSExtValue() - values[0].getSExtValue();
332 int64_t base = values[0].getSExtValue();
333 int64_t baseDiv = highestPow2Divisor(base);
334 for (int64_t o = 0; o < outer; ++o) {
335 int64_t origin = values[o * inner].getSExtValue();
336 baseDiv = std::gcd(baseDiv, highestPow2Divisor(origin));
337 for (int64_t i = 1; i < inner; ++i) {
338 int64_t cur = values[o * inner + i].getSExtValue();
339 int64_t prev = values[o * inner + i - 1].getSExtValue();
340 int64_t diff = cur - prev;
341 if (diff != innerStride)
342 innerStride = std::numeric_limits<int64_t>::min(); // not AP
343 if (diff != 1)
344 innerCont = std::min<int64_t>(innerCont, i);
345 if (diff != 0)
346 innerConst = std::min<int64_t>(innerConst, i);
347 }
348 }
349 AxisInfo v = AxisInfo::getPessimistic(r);
350 if (innerStride == 1)
351 v.contiguity[r - 1] = innerCont;
352 else if (innerStride == 0)
353 v.constancy[r - 1] = innerConst;
354 // For a non-AP inner dim, leave at pessimistic.
355 v.divisibility[r - 1] = baseDiv;
356 if (innerStride != std::numeric_limits<int64_t>::min())
357 v.innerStride = innerStride;
358 propagateIfChanged(results[0], results[0]->join(v));
359 return success();
360 }
361
362 // vector.broadcast: source lattice extends to the broadcast dims with
363 // constancy = full extent on those dims. The trailing dims of the source
364 // (if any) align with the trailing dims of the result.
365 LogicalResult visitBroadcast(vector::BroadcastOp op,
366 ArrayRef<const AxisInfoLattice *> operands,
367 ArrayRef<AxisInfoLattice *> results) {
368 auto resTy = dyn_cast<VectorType>(op.getType());
369 if (!resTy) {
370 setAllPessimistic(op, results);
371 return success();
372 }
373 unsigned rRank = resTy.getRank();
374 AxisInfo src = operands[0]->getValue();
375 AxisInfo v = AxisInfo::getPessimistic(rRank);
376 auto resShape = resTy.getShape();
377 auto srcVt = dyn_cast<VectorType>(op.getSource().getType());
378 unsigned sRank = srcVt ? srcVt.getRank() : 0;
379 // Broadcast aligns trailing dims of the source with trailing dims of
380 // the result. Leading dims that are 1 in source (or absent) are filled
381 // with constancy = result extent.
382 for (unsigned d = 0; d < rRank; ++d) {
383 int64_t resExt = resShape[d];
384 // Index in source aligned with result dim d, or -1 if d is a
385 // broadcast (front-padded) dim.
386 int sIdx = static_cast<int>(d) - static_cast<int>(rRank - sRank);
387 if (sIdx < 0) {
388 v.constancy[d] = resExt;
389 v.contiguity[d] = 1;
390 v.divisibility[d] = src.isInitialized() ? src.divisibility.front() : 1;
391 continue;
392 }
393 int64_t srcExt = srcVt.getShape()[sIdx];
394 if (srcExt == 1 && resExt > 1) {
395 v.constancy[d] = resExt;
396 v.contiguity[d] = 1;
397 v.divisibility[d] = src.isInitialized() ? src.divisibility[sIdx] : 1;
398 } else if (src.isInitialized()) {
399 v.contiguity[d] = src.contiguity[sIdx];
400 v.constancy[d] = src.constancy[sIdx];
401 v.divisibility[d] = src.divisibility[sIdx];
402 }
403 }
404 if (src.knownConstant)
405 v.knownConstant = src.knownConstant;
406 // A broadcast that fans out a scalar / leading-1 source has the broadcast
407 // dim repeating its value -> inner stride 0. Otherwise, the trailing
408 // source dim's stride is preserved when its extent matches the result.
409 auto resShapeArr = resTy.getShape();
410 int64_t innerExt = resShapeArr.back();
411 int sIdxInner =
412 static_cast<int>(rRank - 1) - static_cast<int>(rRank - sRank);
413 if (sIdxInner < 0) {
414 v.innerStride = 0;
415 } else if (srcVt) {
416 int64_t srcInner = srcVt.getShape()[sIdxInner];
417 if (srcInner == 1 && innerExt > 1)
418 v.innerStride = 0;
419 else if (srcInner == innerExt && src.innerStride)
420 v.innerStride = src.innerStride;
421 }
422 propagateIfChanged(results[0], results[0]->join(v));
423 return success();
424 }
425
426 // vector.shape_cast: handle two cases.
427 // (a) Identity-like: shapes match after stripping leading-1 dims —
428 // rebind per-dim info to the new dim positions.
429 // (b) General reshape with the same total element count and row-major
430 // linearization — propagate the source's innermost-dim info (inner
431 // contiguity / constancy) to the destination's innermost dim,
432 // capped by the inner extent. Outer dims stay pessimistic.
433 LogicalResult visitShapeCast(vector::ShapeCastOp op,
434 ArrayRef<const AxisInfoLattice *> operands,
435 ArrayRef<AxisInfoLattice *> results) {
436 auto srcTy = dyn_cast<VectorType>(op.getSource().getType());
437 auto dstTy = dyn_cast<VectorType>(op.getType());
438 if (!srcTy || !dstTy) {
439 setAllPessimistic(op, results);
440 return success();
441 }
442 AxisInfo src = operands[0]->getValue();
443 if (!src.isInitialized()) {
444 setAllPessimistic(op, results);
445 return success();
446 }
447 // Strip leading 1-dims on both sides; if remaining shapes match, this
448 // is an identity-like reshape.
449 auto stripLeading = [](ArrayRef<int64_t> s) {
450 unsigned i = 0;
451 while (i < s.size() && s[i] == 1)
452 ++i;
453 return s.drop_front(i);
454 };
455 auto sCore = stripLeading(srcTy.getShape());
456 auto dCore = stripLeading(dstTy.getShape());
457 unsigned dRank = dstTy.getRank();
458 AxisInfo v = AxisInfo::getPessimistic(dRank);
459 if (sCore == dCore) {
460 unsigned sLead = srcTy.getRank() - sCore.size();
461 unsigned dLead = dRank - dCore.size();
462 for (unsigned d = dLead; d < dRank; ++d) {
463 unsigned sIdx = sLead + (d - dLead);
464 v.contiguity[d] = src.contiguity[sIdx];
465 v.constancy[d] = src.constancy[sIdx];
466 v.divisibility[d] = src.divisibility[sIdx];
467 }
468 } else {
469 // General linear reshape. Propagate source's inner-dim contiguity /
470 // constancy to dst's inner dim, capped by inner extent. Treat inner
471 // info conservatively as the min across all source dims (so a 1-D
472 // source with full contig => inner-dim contig on dst; an N-D source
473 // collapsed to 1-D inherits the inner-dim info).
474 int64_t innerExt = dstTy.getShape().back();
475 int64_t srcContig = std::numeric_limits<int64_t>::max();
476 int64_t srcConst = std::numeric_limits<int64_t>::max();
477 int64_t srcDiv = src.divisibility[src.getRank() - 1];
478 for (unsigned d = 0; d < src.getRank(); ++d) {
479 srcContig = std::min(srcContig, src.contiguity[d]);
480 srcConst = std::min(srcConst, src.constancy[d]);
481 }
482 v.contiguity[dRank - 1] = std::min<int64_t>(srcContig, innerExt);
483 v.constancy[dRank - 1] = std::min<int64_t>(srcConst, innerExt);
484 v.divisibility[dRank - 1] = srcDiv;
485 }
486 if (src.knownConstant)
487 v.knownConstant = src.knownConstant;
488 // Identity-like and general row-major reshape both preserve the source
489 // inner-stride property when the source has a single AP characterization.
490 if (src.innerStride)
491 v.innerStride = src.innerStride;
492 propagateIfChanged(results[0], results[0]->join(v));
493 return success();
494 }
495
496 // vector.transpose: permute per-dim contiguity / constancy / divisibility
497 // according to the transpose permutation. permutation[i] is the source
498 // dim that ends up at result dim i.
499 LogicalResult visitTranspose(vector::TransposeOp op,
500 ArrayRef<const AxisInfoLattice *> operands,
501 ArrayRef<AxisInfoLattice *> results) {
502 auto resTy = dyn_cast<VectorType>(op.getType());
503 if (!resTy) {
504 setAllPessimistic(op, results);
505 return success();
506 }
507 AxisInfo src = operands[0]->getValue();
508 if (!src.isInitialized()) {
509 setAllPessimistic(op, results);
510 return success();
511 }
512 ArrayRef<int64_t> perm = op.getPermutation();
513 unsigned r = resTy.getRank();
514 AxisInfo v = AxisInfo::getPessimistic(r);
515 for (unsigned d = 0; d < r; ++d) {
516 unsigned s = static_cast<unsigned>(perm[d]);
517 v.contiguity[d] = src.contiguity[s];
518 v.constancy[d] = src.constancy[s];
519 v.divisibility[d] = src.divisibility[s];
520 }
521 if (src.knownConstant)
522 v.knownConstant = src.knownConstant;
523 // innerStride only survives when the new inner dim came from the old
524 // inner dim (otherwise a different axis is now the contiguous one).
525 if (src.innerStride && perm.back() == src.getRank() - 1)
526 v.innerStride = src.innerStride;
527 propagateIfChanged(results[0], results[0]->join(v));
528 return success();
529 }
530
531 template <bool IsSub, typename OpTy>
532 LogicalResult visitAddSub(OpTy op, ArrayRef<const AxisInfoLattice *> operands,
533 ArrayRef<AxisInfoLattice *> results) {
534 auto vt = dyn_cast<VectorType>(op.getType());
535 if (!vt) {
536 setAllPessimistic(op, results);
537 return success();
538 }
539 AxisInfo lhs = operands[0]->getValue();
540 AxisInfo rhs = operands[1]->getValue();
541 if (!lhs.isInitialized() || !rhs.isInitialized()) {
542 setAllPessimistic(op, results);
543 return success();
544 }
545 unsigned r = vt.getRank();
546 AxisInfo v = AxisInfo::getPessimistic(r);
547 for (unsigned d = 0; d < r; ++d) {
548 int64_t lhsCont = lhs.contiguity[d];
549 int64_t rhsCont = rhs.contiguity[d];
550 int64_t lhsConst = lhs.constancy[d];
551 int64_t rhsConst = rhs.constancy[d];
552 // contiguity propagates through add when one side is constant on the
553 // run, and through sub only when the rhs is constant on the run.
554 int64_t cont = IsSub ? std::min(lhsCont, rhsConst)
555 : std::max(std::min(lhsCont, rhsConst),
556 std::min(rhsCont, lhsConst));
557 v.contiguity[d] = std::max<int64_t>(1, cont);
558 v.constancy[d] = std::min(lhsConst, rhsConst);
559 v.divisibility[d] = std::gcd(lhs.divisibility[d], rhs.divisibility[d]);
560 }
561 // x + uniform-c: stride preserved. x - uniform-c: same. uniform-c - x:
562 // stride flips sign (only useful for the "stride 0" case, which it
563 // preserves trivially).
564 auto isUniform = [&](const AxisInfo &a) {
565 unsigned inner = vt.getRank() - 1;
566 return a.constancy[inner] >= vt.getShape()[inner];
567 };
568 if (lhs.innerStride && isUniform(rhs)) {
569 v.innerStride = *lhs.innerStride;
570 } else if (rhs.innerStride && isUniform(lhs)) {
571 v.innerStride = IsSub ? -*rhs.innerStride : *rhs.innerStride;
572 }
573 propagateIfChanged(results[0], results[0]->join(v));
574 return success();
575 }
576
577 LogicalResult visitMul(arith::MulIOp op,
578 ArrayRef<const AxisInfoLattice *> operands,
579 ArrayRef<AxisInfoLattice *> results) {
580 auto vt = dyn_cast<VectorType>(op.getType());
581 if (!vt) {
582 setAllPessimistic(op, results);
583 return success();
584 }
585 AxisInfo lhs = operands[0]->getValue();
586 AxisInfo rhs = operands[1]->getValue();
587 if (!lhs.isInitialized() || !rhs.isInitialized()) {
588 setAllPessimistic(op, results);
589 return success();
590 }
591 unsigned r = vt.getRank();
592 auto shape = vt.getShape();
593 AxisInfo v = AxisInfo::getPessimistic(r);
594 auto unitConstant = [](const AxisInfo &a, unsigned d, int64_t extent) {
595 return a.knownConstant && *a.knownConstant == 1 &&
596 a.constancy[d] >= extent;
597 };
598 for (unsigned d = 0; d < r; ++d) {
599 v.constancy[d] = std::min({shape[d], lhs.constancy[d], rhs.constancy[d]});
600 v.divisibility[d] = std::min<int64_t>(
601 kAxisInfoTop, lhs.divisibility[d] * rhs.divisibility[d]);
602 // Multiplying by uniform `s` only keeps contiguity when `s == 1`.
603 if (unitConstant(lhs, d, shape[d]))
604 v.contiguity[d] = std::min(rhs.contiguity[d], shape[d]);
605 else if (unitConstant(rhs, d, shape[d]))
606 v.contiguity[d] = std::min(lhs.contiguity[d], shape[d]);
607 else
608 v.contiguity[d] = 1;
609 }
610 // x * uniform-c: stride scales by c. (Both operands uniform => 0.)
611 unsigned inner = vt.getRank() - 1;
612 auto isUniformInner = [&](const AxisInfo &a) {
613 return a.constancy[inner] >= shape[inner];
614 };
615 if (lhs.innerStride && isUniformInner(rhs) && rhs.knownConstant) {
616 v.innerStride = *lhs.innerStride * *rhs.knownConstant;
617 } else if (rhs.innerStride && isUniformInner(lhs) && lhs.knownConstant) {
618 v.innerStride = *rhs.innerStride * *lhs.knownConstant;
619 }
620 propagateIfChanged(results[0], results[0]->join(v));
621 return success();
622 }
623
624 // arith.divui / arith.divsi / arith.remui / arith.remsi by a uniform
625 // positive constant `c`. The lhs must be an arithmetic progression (AP)
626 // along the inner dim, i.e. its values step by a constant stride `s`, and
627 // `c` must divide `s`.
628 //
629 // Take the inner row `[0, 2, 4, 6, 8, 10, 12, 14]` (stride s = 2) and c = 2:
630 // - Division `/ 2` gives `[0, 1, 2, 3, 4, 5, 6, 7]`: a new AP with stride
631 // `s / c = 1`. A resulting stride of 1 is contiguous, 0 is constant.
632 // - Remainder `% 2` gives `[0, 0, 0, 0, 0, 0, 0, 0]`: every element folds
633 // to the same residue, so the row is constant (stride 0).
634 //
635 // We require positive `c`, so signed and unsigned behave the same.
636 template <bool IsSigned, bool IsRem, typename OpTy>
637 LogicalResult visitDivRem(OpTy op, ArrayRef<const AxisInfoLattice *> operands,
638 ArrayRef<AxisInfoLattice *> results) {
639 auto vt = dyn_cast<VectorType>(op.getType());
640 if (!vt) {
641 setAllPessimistic(op, results);
642 return success();
643 }
644 AxisInfo lhs = operands[0]->getValue();
645 AxisInfo rhs = operands[1]->getValue();
646 if (!lhs.isInitialized() || !rhs.isInitialized()) {
647 setAllPessimistic(op, results);
648 return success();
649 }
650 unsigned r = vt.getRank();
651 unsigned inner = r - 1;
652 auto shape = vt.getShape();
653 AxisInfo v = AxisInfo::getPessimistic(r);
654
655 bool rhsUniform = rhs.constancy[inner] >= shape[inner] && rhs.knownConstant;
656 if (!rhsUniform || *rhs.knownConstant <= 0) {
657 propagateIfChanged(results[0], results[0]->join(v));
658 return success();
659 }
660 int64_t c = *rhs.knownConstant;
661
662 if (!lhs.innerStride) {
663 propagateIfChanged(results[0], results[0]->join(v));
664 return success();
665 }
666 int64_t s = *lhs.innerStride;
667 int64_t baseDivLhs = lhs.divisibility[inner];
668 if (s % c != 0) {
669 propagateIfChanged(results[0], results[0]->join(v));
670 return success();
671 }
672
673 if (IsRem) {
674 // (base + i*s) mod c, with c | s, is the constant base mod c.
675 v.innerStride = 0;
676 v.constancy[inner] = shape[inner];
677 // The remainder is in [0, c-1], so any power-of-two divisor of c is a
678 // lower bound on alignment. Use lhs's existing divisibility too.
679 v.divisibility[inner] = std::gcd(baseDivLhs, highestPow2Divisor(c));
680 } else {
681 if (baseDivLhs % c != 0) {
682 propagateIfChanged(results[0], results[0]->join(v));
683 return success();
684 }
685 int64_t newStride = s / c;
686 v.innerStride = newStride;
687 if (newStride == 1)
688 v.contiguity[inner] = shape[inner];
689 else if (newStride == 0)
690 v.constancy[inner] = shape[inner];
691 v.divisibility[inner] = baseDivLhs / c;
692 }
693 propagateIfChanged(results[0], results[0]->join(v));
694 return success();
695 }
696
697 // arith.andi: `x & m` with a uniform positive constant mask `m`. The
698 // interesting case is `m = P - 1` for a power of 2 `P`, which is the same
699 // as `x % P` (see visitDivRem): masking an inner row whose stride is a
700 // multiple of `P` folds it to a constant.
701 //
702 // Take the row `[0, 2, 4, 6, 8, 10, 12, 14]` (stride 2) and m = 1 (P = 2):
703 // `x & 1` gives `[0, 0, 0, 0, 0, 0, 0, 0]`: constant along the inner dim.
704 //
705 // Also handles the trivial masks `m == 0` (always zero) and all-ones
706 // (identity).
707 LogicalResult visitAndI(arith::AndIOp op,
708 ArrayRef<const AxisInfoLattice *> operands,
709 ArrayRef<AxisInfoLattice *> results) {
710 auto vt = dyn_cast<VectorType>(op.getType());
711 if (!vt) {
712 setAllPessimistic(op, results);
713 return success();
714 }
715 AxisInfo lhs = operands[0]->getValue();
716 AxisInfo rhs = operands[1]->getValue();
717 if (!lhs.isInitialized() || !rhs.isInitialized()) {
718 setAllPessimistic(op, results);
719 return success();
720 }
721 unsigned r = vt.getRank();
722 unsigned inner = r - 1;
723 auto shape = vt.getShape();
724 AxisInfo v = AxisInfo::getPessimistic(r);
725
726 // Look for a uniform constant mask on either side.
727 auto getUniformMask = [&](const AxisInfo &a) -> std::optional<int64_t> {
728 if (a.constancy[inner] >= shape[inner] && a.knownConstant)
729 return a.knownConstant;
730 return std::nullopt;
731 };
732 std::optional<int64_t> mLhs = getUniformMask(lhs);
733 std::optional<int64_t> mRhs = getUniformMask(rhs);
734 if (!mLhs && !mRhs) {
735 propagateIfChanged(results[0], results[0]->join(v));
736 return success();
737 }
738 const AxisInfo &x = mLhs ? rhs : lhs;
739 int64_t m = mLhs ? *mLhs : *mRhs;
740
741 if (m == 0) {
742 v.knownConstant = 0;
743 v.innerStride = 0;
744 v.constancy[inner] = shape[inner];
745 v.divisibility[inner] = kAxisInfoTop;
746 propagateIfChanged(results[0], results[0]->join(v));
747 return success();
748 }
749
750 // `m == P - 1` with P a power of 2 -> equivalent to `x mod P`.
751 if (m > 0 && llvm::isPowerOf2_64(static_cast<uint64_t>(m + 1))) {
752 int64_t P = m + 1;
753 if (x.innerStride && *x.innerStride % P == 0) {
754 v.innerStride = 0;
755 v.constancy[inner] = shape[inner];
756 v.divisibility[inner] =
757 std::gcd(x.divisibility[inner], highestPow2Divisor(P));
758 propagateIfChanged(results[0], results[0]->join(v));
759 return success();
760 }
761 }
762 // Conservative fallback for unrecognized masks.
763 propagateIfChanged(results[0], results[0]->join(v));
764 return success();
765 }
766
767 // arith.shli (left shift) / arith.shrui (logical right shift) by a
768 // uniform constant `k`. These are `* (1 << k)` and `/ (1 << k)`
769 // (truncating, but for non-negative values the trunc is exact when
770 // `(1 << k)` divides the value). We model them by reducing to mul/divui.
771 template <bool IsLeft, typename OpTy>
772 LogicalResult visitShift(OpTy op, ArrayRef<const AxisInfoLattice *> operands,
773 ArrayRef<AxisInfoLattice *> results) {
774 auto vt = dyn_cast<VectorType>(op.getType());
775 if (!vt) {
776 setAllPessimistic(op, results);
777 return success();
778 }
779 AxisInfo lhs = operands[0]->getValue();
780 AxisInfo rhs = operands[1]->getValue();
781 if (!lhs.isInitialized() || !rhs.isInitialized()) {
782 setAllPessimistic(op, results);
783 return success();
784 }
785 unsigned r = vt.getRank();
786 unsigned inner = r - 1;
787 auto shape = vt.getShape();
788 AxisInfo v = AxisInfo::getPessimistic(r);
789
790 if (rhs.constancy[inner] < shape[inner] || !rhs.knownConstant) {
791 propagateIfChanged(results[0], results[0]->join(v));
792 return success();
793 }
794 int64_t k = *rhs.knownConstant;
795 if (k < 0 || k >= 63) {
796 propagateIfChanged(results[0], results[0]->join(v));
797 return success();
798 }
799 int64_t factor = 1LL << k;
800
801 if (IsLeft) {
802 // x << k == x * factor.
803 if (lhs.innerStride) {
804 v.innerStride = *lhs.innerStride * factor;
805 if (*v.innerStride == 1)
806 v.contiguity[inner] = shape[inner];
807 else if (*v.innerStride == 0)
808 v.constancy[inner] = shape[inner];
809 }
810 v.divisibility[inner] =
811 std::min<int64_t>(kAxisInfoTop, lhs.divisibility[inner] * factor);
812 } else {
813 // x >> k == x / factor (for non-negative x); same conditions as divui.
814 if (lhs.innerStride && *lhs.innerStride % factor == 0 &&
815 lhs.divisibility[inner] % factor == 0) {
816 int64_t newStride = *lhs.innerStride / factor;
817 v.innerStride = newStride;
818 if (newStride == 1)
819 v.contiguity[inner] = shape[inner];
820 else if (newStride == 0)
821 v.constancy[inner] = shape[inner];
822 v.divisibility[inner] = lhs.divisibility[inner] / factor;
823 }
824 }
825 propagateIfChanged(results[0], results[0]->join(v));
826 return success();
827 }
828
829 // arith.select: result is at least as constrained as the meet of the two
830 // arms. We propagate fields where both arms agree.
831 LogicalResult visitSelect(arith::SelectOp op,
832 ArrayRef<const AxisInfoLattice *> operands,
833 ArrayRef<AxisInfoLattice *> results) {
834 auto vt = dyn_cast<VectorType>(op.getType());
835 if (!vt) {
836 setAllPessimistic(op, results);
837 return success();
838 }
839 // operands: [cond, true, false]
840 AxisInfo t = operands[1]->getValue();
841 AxisInfo f = operands[2]->getValue();
842 if (!t.isInitialized() || !f.isInitialized()) {
843 setAllPessimistic(op, results);
844 return success();
845 }
846 AxisInfo v = AxisInfo::join(t, f);
847 propagateIfChanged(results[0], results[0]->join(v));
848 return success();
849 }
850
851 template <typename OpTy>
852 LogicalResult visitPassThrough(OpTy op,
853 ArrayRef<const AxisInfoLattice *> operands,
854 ArrayRef<AxisInfoLattice *> results) {
855 if (!isa<VectorType>(op.getType())) {
856 setAllPessimistic(op, results);
857 return success();
858 }
859 propagateIfChanged(results[0], results[0]->join(operands[0]->getValue()));
860 return success();
861 }
862};
863
864} // namespace mlir::xegpu::detail::axis_dataflow
865
866namespace {
867
870
871//===----------------------------------------------------------------------===//
872// Analysis driver.
873//===----------------------------------------------------------------------===//
874
875/// Stamp a `contiguity` attribute on `op` recording the inner-dim contiguity
876/// computed by the analysis. The contiguity is a target-independent property
877/// of the offsets.
878template <typename OpTy>
879static void analyzeAndStampContiguity(OpTy op, DataFlowSolver &solver) {
880 auto offsetsTy = dyn_cast<VectorType>(op.getOffsets().getType());
881 if (!offsetsTy || offsetsTy.getNumElements() <= 1)
882 return;
883 // A pre-existing `contiguity` (user-authored, or stamped by an earlier run)
884 // takes precedence; leave it untouched so the analysis is idempotent.
885 if (op.getContiguity())
886 return;
887 const auto *lat = solver.lookupState<AxisInfoLattice>(op.getOffsets());
888 if (!lat || !lat->getValue().isInitialized())
889 return;
890 const AxisInfo &info = lat->getValue();
891 unsigned innerDim = offsetsTy.getRank() - 1;
892 int64_t inner = offsetsTy.getShape()[innerDim];
893 // The attribute records a contiguity that tiles the inner dim, so it must
894 // divide it (verified on the op). Round the measured run length down to the
895 // largest divisor of `inner` that does not exceed it.
896 int64_t contiguity = std::min<int64_t>(info.contiguity[innerDim], inner);
897 while (contiguity >= 2 && inner % contiguity != 0)
898 --contiguity;
899 if (contiguity < 2)
900 return;
901 op.setContiguity(contiguity);
902}
903
904} // namespace
905
906//===----------------------------------------------------------------------===//
907// Public API.
908//===----------------------------------------------------------------------===//
909
911 DataFlowSolver solver;
914 if (failed(solver.initializeAndRun(root)))
915 return;
916
917 // The solver computed AxisInfo for the whole region in the single
918 // `initializeAndRun` above; offsets shared by several gather/scatter ops are
919 // analyzed only once. This walk is just per-op point lookups into that
920 // result (no re-analysis), turning each cached fact into an attribute.
921 root->walk([&](Operation *op) {
922 if (auto load = dyn_cast<xegpu::LoadGatherOp>(op))
923 analyzeAndStampContiguity(load, solver);
924 else if (auto store = dyn_cast<xegpu::StoreScatterOp>(op))
925 analyzeAndStampContiguity(store, solver);
926 });
927}
return success()
lhs
auto load
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
#define mul(a, b)
#define add(a, b)
#define div(a, b)
#define rem(a, b)
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
The general data-flow analysis solver.
LogicalResult initializeAndRun(Operation *top, llvm::function_ref< bool(DataFlowAnalysis &)> analysisFilter=nullptr)
Initialize analyses starting from the provided top-level operation and run the analysis until fixpoin...
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
result_range getResults()
Definition Operation.h:440
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate and update if it changed.
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
This class represents a lattice holding a specific value of type ValueT.
Value getAnchor() const
Return the value this lattice is located at.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
A sparse forward data-flow analysis for propagating SSA value lattices across the IR by implementing ...
Sparse forward dataflow analysis that computes AxisInfo for vector values reachable from the entry of...
LogicalResult visitOperation(Operation *op, ArrayRef< const AxisInfoLattice * > operands, ArrayRef< AxisInfoLattice * > results) override
static constexpr int64_t kAxisInfoTop
Sentinel "very large" value for unconstrained dimensions.
static AxisInfo entryStateFor(Value v)
Initial lattice value for an SSA value when no transfer function applies.
static int64_t highestPow2Divisor(int64_t v)
Power-of-two divisor of v. Returns kAxisInfoTop when v == 0.
dataflow::Lattice< AxisInfo > AxisInfoLattice
static AxisInfo splatAxisInfo(ArrayRef< int64_t > shape, int64_t c)
AxisInfo for a tensor that is constant c everywhere.
void runContiguityAnalysis(Operation *root)
Run the AxisInfo-based contiguity analysis over root and stamp a contiguity attribute on every xegpu....
Include the generated interface declarations.
Per-dimension axis information for an SSA vector value of integer / index type.
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs)
Conservative join (lattice meet).