MLIR 23.0.0git
XeGPUDialect.cpp
Go to the documentation of this file.
1//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Builders.h"
18#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/Debug.h"
21
22using std::optional;
23
24namespace mlir {
25namespace xegpu {
26
// Registers the dialect's types, operations and attributes.
//
// Each `add*<...>()` call takes its template argument list from an included
// `.cpp.inc` file; the `GET_*_LIST` macro defined immediately before each
// include selects which list that file expands to.
27void XeGPUDialect::initialize() {
28 addTypes<
29#define GET_TYPEDEF_LIST
30#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
31 >();
32 addOperations<
33#define GET_OP_LIST
34#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
35 >();
36 addAttributes<
37#define GET_ATTRDEF_LIST
38#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
39 >();
40}
41#define GET_OP_INTERFACE_CLASSES
42#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
43
44// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
45// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
46// within each distribution unit.
47// Example:
48// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a
49// distribution unit of shape 64x64, we have 2x4 such distribution units.
50// `delinearizedId` is used to identify a 16x32 of a subgroup in each
51// distribution unit.
//
// Returns one coordinate vector per distribution unit; each entry is an
// index `Value` built with `builder` at `loc`.
//
// NOTE(review): the first lines of the signature (return type, function name
// and the `builder` / `loc` parameters) and the declarations of `coordinates`,
// `base` and `adds` are not visible in this listing - confirm against the
// original file before editing the body.
54 SmallVector<Value> delinearizedId,
55 ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
56 ArrayRef<int64_t> srcShape) {
58
59 // A distribution unit must be less than or equal to `srcShape`
 // (each extent is clamped with std::min against the matching srcShape
 // extent).
60 SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
61 llvm::zip_equal(srcShape,
62 computeElementwiseMul(subShapesLayout, subShape)),
63 [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
64
65 // Get the offset of `subShape` within a distribution unit.
 // This is id * subShape extent, per dimension.
 // NOTE(review): this uses llvm::zip rather than llvm::zip_equal, so a rank
 // mismatch between `delinearizedId` and `subShape` is not diagnosed here.
66 SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
67 llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
68 return builder.createOrFold<arith::MulIOp>(
69 loc, std::get<0>(t),
70 builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
71 });
72
73 // For each dist unit
74 for (SmallVector<int64_t> unitOffs :
75 StaticTileOffsetRange(srcShape, distUnitShape)) {
76 // Get dist unit offset within `srcShape`.
 // The static offsets are materialized as index constants.
78 llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
79 return arith::ConstantIndexOp::create(builder, loc, d);
80 });
81 // Calculate `subShape` offset within `srcShape`.
83 llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
84 [&](const auto &t) -> Value {
85 return builder.createOrFold<arith::AddIOp>(
86 loc, std::get<0>(t), std::get<1>(t));
87 });
88 // Do not go beyond `srcShape` bounds.
 // The sum is reduced modulo the srcShape extent (unsigned remainder), so
 // an offset past the end wraps around instead of going out of range.
89 SmallVector<Value> mods = llvm::map_to_vector(
90 llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
91 return builder.createOrFold<arith::RemUIOp>(
92 loc, std::get<0>(t),
93 arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
94 });
95
96 coordinates.push_back(mods);
97 }
98 return coordinates;
99}
100
 // Integer-only coordinate computation: for every distribution unit of
 // `shape`, produces (unit offset + canonicalIds * subShape) % shape, one
 // coordinate vector per unit.
 //
 // NOTE(review): the signature and the declaration of `coordinates` are not
 // visible in this listing; the names `shape`, `layout`, `subShape` and
 // `canonicalIds` are taken from their uses below - confirm against the
 // original file.
104 // Compute distribution unit shape (clamped to srcShape).
105 SmallVector<int64_t> distUnitShape(shape.size());
106 for (size_t i = 0; i < shape.size(); ++i)
107 distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
108
109 // Compute local offset of this ID within a distribution unit.
110 SmallVector<int64_t> localOffset(shape.size());
111 for (size_t i = 0; i < shape.size(); ++i)
112 localOffset[i] = canonicalIds[i] * subShape[i];
113
114 // Enumerate all distribution units and compute coordinates.
116 for (SmallVector<int64_t> unitOffs :
117 StaticTileOffsetRange(shape, distUnitShape)) {
118 SmallVector<int64_t> coord(shape.size());
 // The modulo keeps the coordinate inside `shape` by wrapping around.
119 for (size_t i = 0; i < shape.size(); ++i)
120 coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
121 coordinates.push_back(coord);
122 }
123 return coordinates;
124}
125
126/// Expands per-distribution-unit block-start coordinates into the full list of
127/// element coordinates each block covers: every element of the `subShape`-sized
128/// region (row-major) offset by the block start. Comparing these instead of the
129/// bare block starts lets layouts that differ only in `lane_data` blocking, but
130/// own the same elements in the same order, be recognized as equivalent.
///
/// NOTE(review): the first lines of the signature and the declaration of
/// `expanded` are not visible in this listing - confirm against the original
/// file.
133 ArrayRef<int64_t> subShape) {
 // An all-ones tile makes the offset range below step through `subShape` one
 // element at a time.
134 SmallVector<int64_t> unitTile(subShape.size(), 1);
136 for (const SmallVector<int64_t> &start : blockStarts) {
137 for (SmallVector<int64_t> off : StaticTileOffsetRange(subShape, unitTile)) {
 // Element coordinate = block start + offset within the block.
138 SmallVector<int64_t> coord(start.size());
139 for (size_t i = 0; i < start.size(); ++i)
140 coord[i] = start[i] + off[i];
141 expanded.push_back(std::move(coord));
142 }
143 }
144 return expanded;
145}
146
147/// Returns true if `self` and `other` distribute `shape` identically at
148/// `level`: every id in `[0, size)` owns the same coordinates under both.
149///
150/// At the Lane level, layouts that pack `lane_data` differently can still own
151/// the same per-lane elements in the same order; their block starts differ but
152/// the expanded per-element coordinates match. So block starts are expanded
153/// (via `expandBlockCoords`) before comparing, but only when it can change the
154/// result (Lane level with differing `lane_data`) - otherwise comparing the
155/// cheaper block starts is already exact.
156///
157/// TODO: Extend the same handling to the Subgroup level (sg_data repacks).
///
/// NOTE(review): one parameter line (presumably the one declaring `shape`) is
/// not visible in this listing - confirm against the original file.
158static bool compareDistributedCoords(xegpu::DistributeLayoutAttr self,
159 const xegpu::DistributeLayoutAttr &other,
161 xegpu::LayoutKind level, int64_t size) {
 // Expansion is only needed at the Lane level when the two lane_data
 // blockings differ.
162 bool expandCoords =
163 level == xegpu::LayoutKind::Lane &&
164 self.getEffectiveLaneDataAsInt() != other.getEffectiveLaneDataAsInt();
 // The sub-shapes are fetched once, outside the per-id loop, and stay empty
 // when no expansion is requested.
165 SmallVector<int64_t> selfSubShape, otherSubShape;
166 if (expandCoords) {
167 selfSubShape = self.getEffectiveLaneDataAsInt();
168 otherSubShape = other.getEffectiveLaneDataAsInt();
169 }
 // Compare id by id and stop at the first mismatch.
170 for (int64_t id : llvm::seq<int64_t>(0, size)) {
171 auto coords = self.computeStaticDistributedCoords(id, shape);
172 auto otherCoords = other.computeStaticDistributedCoords(id, shape);
173 if (expandCoords) {
174 coords = expandBlockCoords(coords, selfSubShape);
175 otherCoords = expandBlockCoords(otherCoords, otherSubShape);
176 }
177 if (coords != otherCoords)
178 return false;
179 }
180 return true;
181}
182
183// Checks if the given memref type represents shared local memory (SLM).
184bool XeGPUDialect::isSharedMemory(const MemRefType &memrefTy) {
185 Attribute attr = memrefTy.getMemorySpace();
186 if (!attr)
187 return false; // Default memory space is not shared local memory
188 if (auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr))
189 return intAttr.getInt() == 3;
190 if (auto memrefSpace = llvm::dyn_cast_if_present<MemorySpaceAttr>(attr))
191 return memrefSpace.getValue() == MemorySpace::SLM;
192 if (auto xevmSpace = llvm::dyn_cast_if_present<xevm::AddrSpaceAttr>(attr))
193 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
194 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
195}
196
197//===----------------------------------------------------------------------===//
198// XeGPU_BlockTensorDescAttr
199//===----------------------------------------------------------------------===//
200BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
201 xegpu::MemorySpace memory_space,
202 int array_length,
203 bool boundary_check) {
204 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
205 auto lengthAttr =
206 IntegerAttr::get(IntegerType::get(context, 64), array_length);
207 auto boundaryAttr = BoolAttr::get(context, boundary_check);
208 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
209}
210
211bool BlockTensorDescAttr::hasDefaultsOnly() {
212 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
213 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
214}
215
216//===----------------------------------------------------------------------===//
217// XeGPU_LayoutAttr
218//===----------------------------------------------------------------------===//
219LogicalResult
220LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
221 DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
222 DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
223 DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {
224
225 // Special case for store_matrix
226 if (!sg_layout && !inst_data && !lane_layout)
227 return success();
228
229 // generate code to check sg_laout, inst_data and lane_layout having the same
230 // rank if they are not null.
231
232 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
233 return emitError()
234 << "expected sg_layout and inst_data to have the same rank";
235 }
236
237 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
238 return emitError()
239 << "expected sg_layout and lane_layout to have the same rank";
240 }
241
242 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
243 return emitError() << "expected inst_data and lane_layout to have the same "
244 "rank, got inst_data "
245 << inst_data.size() << ", lane_layout "
246 << lane_layout.size();
247 }
248
249 if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
250 return emitError() << "sg_layout and sg_data must be used together";
251 if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
252 return emitError()
253 << "expected sg_data and sg_layout to have the same rank";
254
255 if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
256 return emitError() << "lane_layout and lane_data must be used together";
257 if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
258 return emitError()
259 << "expected lane_data and lane_layout to have the same rank";
260
261 if (order) {
262 if (!sg_layout && !lane_layout)
263 return emitError()
264 << "expected sg_layout/lane_layout being used with order";
265
266 if (sg_layout && order.size() != sg_layout.size())
267 return emitError()
268 << "expected order and sg_layout to have the same rank";
269
270 if (lane_layout && order.size() != lane_layout.size())
271 return emitError()
272 << "expected order and lane_layout to have the same rank";
273 }
274
275 return success();
276}
277
278FailureOr<SmallVector<Value>>
279LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
280
281 SmallVector<int64_t> sgLayoutInt;
282 if (isForWorkgroup()) {
283 sgLayoutInt = getEffectiveSgLayoutAsInt();
284 } else if (isForSubgroup()) {
285 sgLayoutInt = getEffectiveLaneLayoutAsInt();
286 } else {
287 return failure();
288 }
289
290 DenseI32ArrayAttr orderAttr = getOrder();
291
292 // Handle order attribute
293 SmallVector<int64_t> order;
294 if (orderAttr && !orderAttr.empty()) {
295 order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
296 return static_cast<int64_t>(idx);
297 });
298 } else {
299 // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
300 order = llvm::to_vector(
301 llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
302 }
303
304 if (order.size() != sgLayoutInt.size()) {
305 return failure();
306 }
307
308 SmallVector<Value> result(sgLayoutInt.size());
309 Value remaining = linearId;
310
311 /// Process dimensions in the order they appear in the order array
312 /// The first dimension in order is the fastest-changing
313 ///
314 /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
315 ///
316 /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
317 /// result=[?,?,?]
318 ///
319 /// i=0 (process columns, dimIdx=2, dimSize=4):
320 /// result[2] = 22 % 4 = 2 (column coordinate)
321 /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
322 ///
323 /// i=1 (process rows, dimIdx=1, dimSize=4):
324 /// result[1] = 5 % 4 = 1 (row coordinate)
325 /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
326 ///
327 /// i=2 (process layers, dimIdx=0, dimSize=2):
328 /// result[0] = 1 % 2 = 1 (layer coordinate)
329 /// (no remaining update - last iteration)
330 ///
331 /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
332 for (size_t i = 0; i < order.size(); ++i) {
333 int64_t dimIdx = order[i];
334 int64_t dimSize = sgLayoutInt[dimIdx];
335
336 Value dimSizeVal =
337 builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
338
339 /// Extract the coordinate for this dimension using modulo operation
340 /// This gives us "how far within this dimension" we are
341 /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
342 /// this dimension)
343 result[dimIdx] =
344 builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
345
346 /// Update remaining for the next dimension by removing what we've already
347 /// processed. Division tells us "how many complete groups of this dimension
348 /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
349 /// completed 5 groups of 4) Skip this for the last iteration since there's
350 /// no next dimension to process
351 if (i < order.size() - 1) {
352 remaining =
353 builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
354 }
355 }
356 return result;
357}
358
359/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
360/// instructions for computing multi-dimensional offsets when distributed by
361/// LayoutAttr.
362FailureOr<SmallVector<SmallVector<Value>>>
363LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
364 Value linearId, ArrayRef<int64_t> shape) {
365 SmallVector<int64_t> layout;
366 SmallVector<int64_t> subShape;
367 if (isForWorkgroup()) {
368 layout = getEffectiveSgLayoutAsInt();
369 subShape = getEffectiveSgDataAsInt();
370 } else if (isForSubgroup()) {
371 layout = getEffectiveLaneLayoutAsInt();
372 subShape = getEffectiveLaneDataAsInt();
373 } else {
374 return failure();
375 }
376 assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
377 "distributed coordinates computation");
378
379 // delinearize Ids
380 auto maybeIds = delinearizeId(builder, loc, linearId);
381 if (failed(maybeIds))
382 return failure();
383 SmallVector<Value> ids = *maybeIds;
384
385 return genCoordinates(builder, loc, ids, layout, subShape, shape);
386}
387
388bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
389 if (dyn_cast<xegpu::SliceAttr>(other))
390 return false;
391
392 return *this == dyn_cast<xegpu::LayoutAttr>(other);
393}
394
395/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
396/// compute multi-dimensional offsets for a given linear ID when distributed by
397/// LayoutAttr.
398SmallVector<SmallVector<int64_t>>
399LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
400 ArrayRef<int64_t> shape) {
401 SmallVector<int64_t> layoutVec;
402 SmallVector<int64_t> subShape;
403 SmallVector<int64_t> instData;
404 if (isForWorkgroup()) {
405 layoutVec = getEffectiveSgLayoutAsInt();
406 subShape = getEffectiveSgDataAsInt();
407 } else if (isForSubgroup()) {
408 instData = getEffectiveInstDataAsInt();
409 layoutVec = getEffectiveLaneLayoutAsInt();
410 subShape = getEffectiveLaneDataAsInt();
411 }
412 if (!instData.empty()) {
413 linearId = 0;
414 subShape = instData;
415 }
416 assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
417
418 // Delinearize the linear ID using the order attribute.
419 SmallVector<int64_t> order = getEffectiveOrderAsInt();
420 SmallVector<int64_t> delinearizedId(layoutVec.size());
421 int64_t remaining = linearId;
422 for (size_t i = 0; i < order.size(); ++i) {
423 int64_t dimIdx = order[i];
424 delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
425 remaining = remaining / layoutVec[dimIdx];
426 }
427
428 return genStaticCoordinates(delinearizedId, layoutVec, subShape, shape);
429}
430
431// set the layout for unit dims: sg_data, inst_data and lane_data to 1
432DistributeLayoutAttr
433LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
434 auto sgDataOpt = getSgData();
435 auto instDataOpt = getInstData();
436 auto laneDataOpt = getLaneData();
437
438 SmallVector<int32_t> sgData;
439 SmallVector<int32_t> instData;
440 SmallVector<int32_t> laneData;
441
442 if (sgDataOpt)
443 sgData = llvm::to_vector(sgDataOpt.asArrayRef());
444
445 if (instDataOpt)
446 instData = llvm::to_vector(instDataOpt.asArrayRef());
447
448 if (laneDataOpt)
449 laneData = llvm::to_vector(laneDataOpt.asArrayRef());
450
451 for (auto dim : unitDims) {
452 if (dim < static_cast<int64_t>(sgData.size()))
453 sgData[dim] = 1;
454 if (dim < static_cast<int64_t>(instData.size()))
455 instData[dim] = 1;
456 if (dim < static_cast<int64_t>(laneData.size()))
457 laneData[dim] = 1;
458 }
459
460 return LayoutAttr::get(
461 getContext(), getSgLayout(),
462 sgData.empty() ? DenseI32ArrayAttr()
464 instData.empty() ? DenseI32ArrayAttr()
465 : DenseI32ArrayAttr::get(getContext(), instData),
466 getLaneLayout(),
467 laneData.empty() ? DenseI32ArrayAttr()
468 : DenseI32ArrayAttr::get(getContext(), laneData),
469 getOrder());
470}
471
472// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
473DistributeLayoutAttr
474LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
475 auto sgLayoutOpt = getSgLayout();
476 auto laneLayoutOpt = getLaneLayout();
477
478 SmallVector<int32_t> sgLayout;
479 SmallVector<int32_t> laneLayout;
480
481 if (sgLayoutOpt)
482 sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
483 if (laneLayoutOpt)
484 laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
485
486 for (auto dim : unitDims) {
487 if (dim < static_cast<int64_t>(sgLayout.size()))
488 sgLayout[dim] = 1;
489 if (dim < static_cast<int64_t>(laneLayout.size()))
490 laneLayout[dim] = 1;
491 }
492
493 return LayoutAttr::get(
494 getContext(),
495 sgLayout.empty() ? DenseI32ArrayAttr()
496 : DenseI32ArrayAttr::get(getContext(), sgLayout),
497 getSgData(), getInstData(),
498 laneLayout.empty() ? DenseI32ArrayAttr()
499 : DenseI32ArrayAttr::get(getContext(), laneLayout),
500 getLaneData(), getOrder());
501}
502
503// Derive a new layout with sg_data, inst_data and lane_data set to the
504// specified values for the given dimension
505DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
506 int64_t instData,
507 int64_t laneData) {
508
509 SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
510 SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
511 SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
512
513 if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
514 sgDataVec[dim] = sgData;
515 if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
516 instDataVec[dim] = instData;
517 if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
518 laneDataVec[dim] = laneData;
519
520 SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
521 SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
522 SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
523
524 return LayoutAttr::get(
525 getContext(), getSgLayout(),
526 sgDataVec.empty() ? DenseI32ArrayAttr()
527 : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
528 instDataVec.empty() ? DenseI32ArrayAttr()
529 : DenseI32ArrayAttr::get(getContext(), instDataVec32),
530 getLaneLayout(),
531 laneDataVec.empty() ? DenseI32ArrayAttr()
532 : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
533 getOrder());
534}
535
536// Derive a new layout by removing dimensions.
537// `dimGroup` specifies a group of dimensions to be removed in the derived
538// layout.
539DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
540
541 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
542 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
543 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
544 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
545 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
546 DenseI32ArrayAttr origOrderAttr = getOrder();
547
548 SmallVector<int64_t> sortedDimGroup = dimGroup;
549 llvm::sort(sortedDimGroup);
550
551 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
552 if (!sgLayout.empty()) {
553 sgLayout.erase(sgLayout.begin() + dimIdx);
554 sgData.erase(sgData.begin() + dimIdx);
555 }
556 if (!instData.empty())
557 instData.erase(instData.begin() + dimIdx);
558 if (!laneLayout.empty()) {
559 laneLayout.erase(laneLayout.begin() + dimIdx);
560 laneData.erase(laneData.begin() + dimIdx);
561 }
562 }
563
564 // Only emit a new order attribute when the input had one, so that "no
565 // order" inputs do not gain a synthetic default that would later trip
566 // adjacency checks in collapseDims.
567 SmallVector<int64_t> newOrder;
568 if (origOrderAttr && !origOrderAttr.empty()) {
569 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
570 for (int64_t d : origOrder) {
571 if (llvm::is_contained(dimGroup, d))
572 continue;
573 int64_t offset =
574 llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
575 newOrder.push_back(d - offset);
576 }
577 if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
578 newOrder.clear();
579 }
580
581 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
582 if (v.empty())
583 return nullptr;
584 SmallVector<int32_t> v32(v.begin(), v.end());
585 return DenseI32ArrayAttr::get(getContext(), v32);
586 };
587 auto droppedLayout = xegpu::LayoutAttr::get(
588 getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
589 toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
590 return droppedLayout;
591}
592
593// Derive a new layout by collapsing dimensions.
594// `dimGroup` specifies a group of adjacent dimensions
595// that are collapsed into a single dimension in the derived layout.
//
// For every field that is present, the entries of the grouped dims are
// multiplied together and the product is placed at the position of the
// smallest dim index in the group. Aborts via llvm::report_fatal_error when
// the group fails the adjacency check below.
//
// NOTE(review): `dimGroup` is assumed to be non-empty - `sortedDimGroup.front()`
// is read unconditionally further down.
596DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
597
598 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
599 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
600 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
601 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
602 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
603 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
604
 // Adjacency check: walk the group in increasing dim index and compare each
 // dim with its predecessor in the group.
605 SmallVector<int64_t> sortedDimGroup = dimGroup;
606 llvm::sort(sortedDimGroup);
607 int64_t dimBeforeCurrent = -1;
608 for (auto dimIdx : sortedDimGroup) {
609 // when order attr is present, adjacency dims are values like [3, 2, 1, 0]
610 // in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
611 // in increasing order
 // NOTE(review): as written, the order-based check passes only when
 // origOrder[prev] == origOrder[cur] - 1, i.e. the order value increases by
 // one from one grouped dim to the next; the "decreasing" wording above
 // does not obviously match - confirm which is intended.
612 if (dimBeforeCurrent >= 0) {
613 if (getOrder() && !getOrder().empty()) {
614 int64_t orderBefore = origOrder[dimBeforeCurrent];
615 int64_t orderCurrent = origOrder[dimIdx];
616 if (orderBefore != (orderCurrent - 1))
617 llvm::report_fatal_error(
618 "dimensions being collapsed must be adjacent in order");
619 } else {
620 if (dimIdx != (dimBeforeCurrent + 1))
621 llvm::report_fatal_error(
622 "dimensions being collapsed must be adjacent");
623 }
624 }
625 dimBeforeCurrent = dimIdx;
626 }
627
 // Position that receives the collapsed entry: the smallest dim index in the
 // group.
628 int firstDim = sortedDimGroup.front();
629
630 // collapse the dimensions in dimGroup into one dimension by multiplying their
631 // sizes together
632
 // In each section below the grouped entries are erased from the highest
 // index down, so earlier erasures do not shift the positions still to be
 // removed; the product is then inserted at `firstDim`.
633 if (!sgLayout.empty()) {
634 int64_t collapsedSglayout = 1, collapsedSgData = 1;
635 for (auto dimIdx : dimGroup) {
636 collapsedSglayout *= sgLayout[dimIdx];
637 collapsedSgData *= sgData[dimIdx];
638 }
639 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
640 sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
641 sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
642 }
643 sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
644 sgData.insert(sgData.begin() + firstDim, collapsedSgData);
645 }
646
647 if (!instData.empty()) {
648 int64_t collapsedInstData = 1;
649 for (auto dimIdx : dimGroup)
650 collapsedInstData *= instData[dimIdx];
651 for (auto dimIdx : llvm::reverse(sortedDimGroup))
652 instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
653 instData.insert(instData.begin() + firstDim, collapsedInstData);
654 }
655
656 if (!laneLayout.empty()) {
657 int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
658 for (auto dimIdx : dimGroup) {
659 collapsedLaneLayout *= laneLayout[dimIdx];
660 collapsedLaneData *= laneData[dimIdx];
661 }
662 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
663 laneLayout.erase(laneLayout.begin() + dimIdx,
664 laneLayout.begin() + dimIdx + 1);
665 laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
666 }
667 laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
668 laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
669 }
670
 // A new order is produced only when this layout carries an order attribute;
 // otherwise `newOrder` stays empty and becomes a null attribute below.
671 SmallVector<int64_t> newOrder;
672 DenseI32ArrayAttr orderAttr = getOrder();
673 if (orderAttr && !orderAttr.empty()) {
674
 // Keep one order entry for the group (the one at position `firstDim`) and
 // erase the entries at the other grouped positions.
 // NOTE(review): the erase is by position in `origOrder`, not by the dim
 // value stored there - confirm this matches the intended meaning of order.
675 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
676 if (dimIdx != firstDim)
677 origOrder.erase(origOrder.begin() + dimIdx);
678 }
679 // say we have orderVec = {5, 3, 2, 1, 0}
680 // Create indices [0, 1, 2, 3, 4]
681 SmallVector<size_t> indices =
682 llvm::to_vector(llvm::seq<size_t>(0, origOrder.size()));
683
684 // Sort indices based on corresponding values
685 llvm::sort(indices,
686 [&](size_t a, size_t b) { return origOrder[a] < origOrder[b]; });
687
 // NOTE(review): `newOrder` is the sorted index list itself (positions of
 // the remaining values in increasing value order), converted to int64_t.
688 newOrder = llvm::to_vector(llvm::map_range(
689 indices, [&](size_t i) { return static_cast<int64_t>(i); }));
690 }
691
 // Narrow to 32-bit and wrap in an attribute; an empty vector maps to a null
 // attribute.
692 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
693 if (v.empty())
694 return nullptr;
695 SmallVector<int32_t> v32(v.begin(), v.end());
696 return DenseI32ArrayAttr::get(getContext(), v32);
697 };
698 auto collapsedLayout = xegpu::LayoutAttr::get(
699 getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
700 toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
701 return collapsedLayout;
702}
703
704// Derive a new layout by expanding a single dimension `dim` into multiple
705// adjacent dimensions whose extents are given by `targetShape`.
706//
707// Distribution is always outer-to-inner to make sure larger contiguous
708// chunks are given to each compute unit first. Concretely: sg_layout and
709// lane_layout walk the new dims outer-to-inner so each compute unit owns a
710// contiguous run after collapse; sg_data / lane_data / inst_data fill
711// innermost-first so the per-unit data tile is contiguous in the
712// fastest-varying expanded dim. The expanded dims are assumed to be in
713// row-major (FCD-first) order, which is intrinsic to `vector.shape_cast`'s
714// linear-order-preserving semantics.
715//
716// Distribution policy on the expanded src dims (replacing `dim`):
717// - sg_layout / lane_layout: spread outer-to-inner; each dim takes
718// min(remaining, targetShape[i]); leftover spills into the next inner
719// dim.
720// - sg_data: fill innermost-first, capped per dim by
721// targetShape[i] / sgLayout[i] (the per-sg share of the extent),
722// or targetShape[i] (if sg_data is replicated across all subgroups).
723// - lane_data: fill innermost-first, capped per dim by
724// (targetShape[i] / sgLayout[i]) / laneLayout[i] (the per-lane share of
725// the per-sg extent).
726// - inst_data: seeded from laneLayout[i] * laneData[i] per dim, then the
727// remaining factor is distributed innermost-first (capped per dim by
728// the per-sg extent).
729// - order: the original dim index is replaced by the expanded dim indices
730// in innermost-fastest order; entries past `dim` shift up by
731// `targetShape.size() - 1`.
732//
733// Examples (only the affected dim shown; assume rank-1 input):
734// layout<sg_layout=[8], sg_data=[512]>, expandDim(0, [8, 16, 32])
735// -> sg_layout=[8, 1, 1], sg_data=[1, 16, 32]
736// layout<sg_layout=[16], sg_data=[32]>, expandDim(0, [2, 8, 32])
737// -> sg_layout=[2, 8, 1], sg_data=[1, 1, 32] (sg_layout spills inward)
738// layout<inst_data=[32], lane_layout=[16], lane_data=[1]>,
739// expandDim(0, [8, 16, 32])
740// -> inst_data=[8, 2, 2], lane_layout=[8, 2, 1], lane_data=[1, 1, 1]
741DistributeLayoutAttr LayoutAttr::expandDim(int64_t dim,
742 ArrayRef<int64_t> targetShape) {
743 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
744 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
745 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
746 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
747 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
748
749 int64_t origRank = getRank();
750 int64_t expCount = static_cast<int64_t>(targetShape.size());
751 assert(dim >= 0 && dim < origRank && "dim out of range");
752 assert(expCount >= 1 && "targetShape must have at least one dim");
753 int64_t newRank = origRank + expCount - 1;
754
755 // Snapshot the per-dim values we need before any field is mutated by
756 // splice() below; without this, computations that read e.g. `laneLayout[dim]`
757 // after laneLayout has already been expanded would see the wrong value.
758 int64_t origSgLayoutDim = sgLayout.empty() ? 1 : sgLayout[dim];
759 int64_t origSgDataDim = sgData.empty() ? 1 : sgData[dim];
760 int64_t origLaneLayoutDim = laneLayout.empty() ? 1 : laneLayout[dim];
761 int64_t origLaneDataDim = laneData.empty() ? 1 : laneData[dim];
762 int64_t origInstDataDim = instData.empty() ? 1 : instData[dim];
763
764 // Spread `total` across the new dims (length expCount), capped per dim by
765 // `dimSizeCap[i]`. `outerToInner` selects iteration direction
766 // (true = i=0..n-1).
767 auto spread = [&](int64_t total, ArrayRef<int64_t> dimSizeCap,
768 bool outerToInner) -> SmallVector<int64_t> {
769 SmallVector<int64_t> out(expCount, 1);
770 int64_t remaining = total;
771 auto step = [&](int64_t i) {
772 if (remaining == 1)
773 return;
774 int64_t take = std::min(remaining, dimSizeCap[i]);
775 assert(take > 0 && "expandDim distribution must not be zero");
776 assert(remaining % take == 0 &&
777 "expandDims must divide evenly across dims");
778 out[i] = take;
779 remaining /= take;
780 };
781 if (outerToInner)
782 for (int64_t i = 0; i < expCount; ++i)
783 step(i);
784 else
785 for (int64_t i = expCount - 1; i >= 0; --i)
786 step(i);
787 assert(remaining == 1 && "expandDims total must fit within target shape");
788 return out;
789 };
790
791 // Splice `expanded` (length expCount) into `vec` at position `dim`,
792 // replacing the single entry at `dim`.
793 auto splice = [&](SmallVector<int64_t> &vec, ArrayRef<int64_t> expanded) {
794 if (vec.empty())
795 return;
796 vec.erase(vec.begin() + dim);
797 vec.insert(vec.begin() + dim, expanded.begin(), expanded.end());
798 };
799
800 bool hasSgLayout = !sgLayout.empty();
801 bool hasSgData = !sgData.empty();
802 bool hasLaneLayout = !laneLayout.empty();
803 bool hasLaneData = !laneData.empty();
804 bool hasInstData = !instData.empty();
805
806 // sg_layout / sg_data
807 SmallVector<int64_t> expSgLayout(expCount, 1);
808 if (hasSgLayout) {
809 expSgLayout = spread(origSgLayoutDim, targetShape, /*outerToInner=*/true);
810 splice(sgLayout, expSgLayout);
811 }
812 bool sgDataReplicated =
813 hasSgData && origSgDataDim == computeProduct(targetShape);
814 if (hasSgData) {
815 SmallVector<int64_t> dimSizeCap(targetShape.begin(), targetShape.end());
816 if (hasSgLayout && !sgDataReplicated)
817 for (int64_t i = 0; i < expCount; ++i)
818 dimSizeCap[i] /= expSgLayout[i];
819 SmallVector<int64_t> expSgData =
820 spread(origSgDataDim, dimSizeCap, /*outerToInner=*/false);
821 splice(sgData, expSgData);
822 }
823
824 // Per-sg view used as the base for lane_layout / lane_data / inst_data:
825 // targetShape[i] / sg_layout[i] when sg_layout is present, else
826 // targetShape itself.
827 SmallVector<int64_t> perSgShape(targetShape.begin(), targetShape.end());
828 if (hasSgLayout && !sgDataReplicated)
829 for (int64_t i = 0; i < expCount; ++i)
830 perSgShape[i] /= expSgLayout[i];
831
832 // lane_layout / lane_data
833 SmallVector<int64_t> expLaneLayout(expCount, 1);
834 SmallVector<int64_t> expLaneData(expCount, 1);
835 if (hasLaneLayout) {
836 expLaneLayout = spread(origLaneLayoutDim, perSgShape,
837 /*outerToInner=*/true);
838 splice(laneLayout, expLaneLayout);
839 }
840 if (hasLaneData) {
841 SmallVector<int64_t> dimSizeCap(perSgShape.begin(), perSgShape.end());
842 if (hasLaneLayout)
843 for (int64_t i = 0; i < expCount; ++i)
844 dimSizeCap[i] /= expLaneLayout[i];
845 expLaneData = spread(origLaneDataDim, dimSizeCap, /*outerToInner=*/false);
846 splice(laneData, expLaneData);
847 }
848
849 // inst_data: when lane info is present, the per-lane atom
850 // `laneLayout[i] * laneData[i]` is the minimum granularity each new dim must
851 // hold to keep the lane-level distribution unit aligned with inst_data; the
852 // remaining factor `inst_data / laneAtom` is then spread innermost-first
853 // (capped by `perSgShape[i] / atom[i]`) and multiplied back onto the atom.
854 // Without lane info, fall back to a plain innermost-first spread over
855 // perSgShape.
856 if (hasInstData) {
857 SmallVector<int64_t> expInstData;
858 if (!hasLaneLayout || !hasLaneData) {
859 expInstData = spread(origInstDataDim, perSgShape, /*outerToInner=*/false);
860 } else {
861 int64_t laneAtom = origLaneLayoutDim * origLaneDataDim;
862 SmallVector<int64_t> atom(expCount, 1);
863 SmallVector<int64_t> dimSizeCap(expCount, 1);
864 for (int64_t i = 0; i < expCount; ++i) {
865 atom[i] = expLaneLayout[i] * expLaneData[i];
866 dimSizeCap[i] = perSgShape[i] / atom[i];
867 }
868 expInstData = spread(origInstDataDim / laneAtom, dimSizeCap,
869 /*outerToInner=*/false);
870 for (int64_t i = 0; i < expCount; ++i)
871 expInstData[i] *= atom[i];
872 }
873 splice(instData, expInstData);
874 }
875
876 // order: replace `dim`'s entry with the expanded dim indices in
877 // innermost-fastest order; shift every other entry past `dim` up by
878 // (expCount - 1).
879 SmallVector<int64_t> newOrder;
880 DenseI32ArrayAttr orderAttr = getOrder();
881 if (orderAttr && !orderAttr.empty()) {
882 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
883 newOrder.reserve(newRank);
884 for (int64_t o : origOrder) {
885 if (o == dim) {
886 // Innermost dim of the expanded group is fastest-varying.
887 for (int64_t i = expCount - 1; i >= 0; --i)
888 newOrder.push_back(dim + i);
889 } else if (o > dim) {
890 newOrder.push_back(o + expCount - 1);
891 } else {
892 newOrder.push_back(o);
893 }
894 }
895 }
896
897 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
898 if (v.empty())
899 return nullptr;
900 SmallVector<int32_t> v32(v.begin(), v.end());
901 return DenseI32ArrayAttr::get(getContext(), v32);
902 };
903 return xegpu::LayoutAttr::get(getContext(), toAttr(sgLayout), toAttr(sgData),
904 toAttr(instData), toAttr(laneLayout),
905 toAttr(laneData), toAttr(newOrder));
906}
907
908// Derive a new layout by transpose the layout using `permutation`.
909DistributeLayoutAttr LayoutAttr::transposeDims(ArrayRef<int64_t> permutation) {
910
911 SmallVector<int64_t> origSgLayout = getEffectiveSgLayoutAsInt();
912 SmallVector<int64_t> origSgData = getEffectiveSgDataAsInt();
913 SmallVector<int64_t> origInstData = getEffectiveInstDataAsInt();
914 SmallVector<int64_t> origLaneLayout = getEffectiveLaneLayoutAsInt();
915 SmallVector<int64_t> origLaneData = getEffectiveLaneDataAsInt();
916 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
917
918 SmallVector<int32_t> sgLayout;
919 SmallVector<int32_t> sgData;
920 SmallVector<int32_t> instData;
921 SmallVector<int32_t> laneLayout;
922 SmallVector<int32_t> laneData;
923 SmallVector<int32_t> order;
924
925 for (int64_t idx : permutation) {
926 if (!origLaneLayout.empty()) {
927 laneLayout.push_back(static_cast<int32_t>(origLaneLayout[idx]));
928 laneData.push_back(static_cast<int32_t>(origLaneData[idx]));
929 }
930 if (!origInstData.empty())
931 instData.push_back(static_cast<int32_t>(origInstData[idx]));
932 if (!origSgLayout.empty()) {
933 sgLayout.push_back(static_cast<int32_t>(origSgLayout[idx]));
934 sgData.push_back(static_cast<int32_t>(origSgData[idx]));
935 }
936 order.push_back(static_cast<int32_t>(origOrder[idx]));
937 }
938 if (origLaneLayout.empty() && origSgLayout.empty())
939 order.clear();
940
941 auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
942 return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
943 };
944 return xegpu::LayoutAttr::get(getContext(), toAttr(sgLayout), toAttr(sgData),
945 toAttr(instData), toAttr(laneLayout),
946 toAttr(laneData), toAttr(order));
947}
948
949/// Check if this layout is a transpose of another layout.
950bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
951 ArrayRef<int64_t> perm,
952 const xegpu::LayoutKind kind) {
953 if (!other)
954 return false;
955 if (getRank() != other.getRank() ||
956 perm.size() != static_cast<size_t>(getRank()))
957 return false;
958 if (!isPermutationVector(perm))
959 return false;
960 // vector.transpose semantics: dst[i] = src[perm[i]]. So `this` (= dst) is a
961 // transpose of `other` (= src) via `perm` iff for all i, dst[i] ==
962 // src[perm[i]].
963 auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src,
964 ArrayRef<int64_t> perm) {
965 for (const auto &ta : llvm::enumerate(perm)) {
966 if (dst[ta.index()] != src[ta.value()])
967 return false;
968 }
969 return true;
970 };
971 if (kind == xegpu::LayoutKind::Subgroup)
972 return checkTranspose(getEffectiveSgLayoutAsInt(),
973 other.getEffectiveSgLayoutAsInt(), perm) &&
974 checkTranspose(getEffectiveSgDataAsInt(),
975 other.getEffectiveSgDataAsInt(), perm) &&
976 checkTranspose(getEffectiveOrderAsInt(),
977 other.getEffectiveOrderAsInt(), perm);
978 if (kind == xegpu::LayoutKind::InstData)
979 return checkTranspose(getEffectiveInstDataAsInt(),
980 other.getEffectiveInstDataAsInt(), perm);
981 if (kind == xegpu::LayoutKind::Lane)
982 return checkTranspose(getEffectiveLaneLayoutAsInt(),
983 other.getEffectiveLaneLayoutAsInt(), perm) &&
984 checkTranspose(getEffectiveLaneDataAsInt(),
985 other.getEffectiveLaneDataAsInt(), perm) &&
986 checkTranspose(getEffectiveOrderAsInt(),
987 other.getEffectiveOrderAsInt(), perm);
988
989 return false;
990}
991
992bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
993 SmallVector<int64_t> shape,
994 xegpu::LayoutKind level) {
995 if (!other)
996 return false;
997 if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
998 // short cut when order is the same, no need to compute coords and compare
999 if (level == xegpu::LayoutKind::Subgroup)
1000 if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
1001 getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
1002 return true;
1003 if (level == xegpu::LayoutKind::Lane)
1004 if (getEffectiveLaneLayoutAsInt() ==
1005 other.getEffectiveLaneLayoutAsInt() &&
1006 getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
1007 return true;
1008 }
1009
1010 auto compareCoordsForAllIds = [&](int64_t size) {
1011 return compareDistributedCoords(*this, other, shape, level, size);
1012 };
1013
1014 if (level == xegpu::LayoutKind::Subgroup) {
1015 int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
1016 return compareCoordsForAllIds(wgSize);
1017 }
1018 if (level == xegpu::LayoutKind::InstData) {
1019 return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
1020 }
1021 if (level == xegpu::LayoutKind::Lane) {
1022 int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
1023 return compareCoordsForAllIds(subgroupSize);
1024 }
1025 return true;
1026}
1027
1028//===----------------------------------------------------------------------===//
1029// XeGPU_SliceAttr
1030//===----------------------------------------------------------------------===//
1031LogicalResult
1032SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
1033 xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
1034
1035 if (!dims)
1036 return emitError() << "expected dims attribute";
1037
1038 // check every element in dims is unique and smaller than rank
1039 llvm::SmallDenseSet<int64_t> seen;
1040 for (int64_t dim : dims.asArrayRef()) {
1041 if (dim < 0)
1042 return emitError() << "invalid dim (" << dim << ") in slice attribute.";
1043 if (!seen.insert(dim).second)
1044 return emitError() << "repeated dim (" << dim << ") in slice attribute.";
1045 }
1046 return success();
1047}
1048
1049SliceAttr SliceAttr::flatten() const {
1050 xegpu::DistributeLayoutAttr parent = getParent();
1051 SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
1052
1053 while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
1054 parent = sliceAttr.getParent();
1055 slicedDims.push_back(sliceAttr.getDims());
1056 }
1057
1058 auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
1059 SmallVector<int64_t> indices =
1060 llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
1061
1062 // get remaining dims (flattened) by applying slice ops with all slicedDims
1063 SmallVector<int64_t> remainingDims(indices);
1064 for (auto dim : llvm::reverse(slicedDims))
1065 remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
1066 dim.asArrayRef());
1067
1068 // get flattened sliced dims by applying slice ops with the remaining dims
1069 SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
1070 llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
1071
1072 return xegpu::SliceAttr::get(
1073 getContext(), layoutAttr,
1074 DenseI64ArrayAttr::get(getContext(), flattenedDims));
1075}
1076
1077FailureOr<SmallVector<Value>>
1078SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
1079 SliceAttr attr = flatten();
1080 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
1081 return parent.delinearizeId(builder, loc, linearId);
1082}
1083
1084// Implements DistributeLayoutAttr::computeDistributedCoords to generate
1085// instructions for computing multi-dimensional offsets when distributed by
1086// LayoutAttr.
1087FailureOr<SmallVector<SmallVector<Value>>>
1088SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
1089 Value linearId, ArrayRef<int64_t> shape) {
1090 assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
1091
1092 SmallVector<int64_t> layout;
1093 SmallVector<int64_t> subShape;
1094 if (isForWorkgroup()) {
1095 layout = getEffectiveSgLayoutAsInt();
1096 subShape = getEffectiveSgDataAsInt();
1097 } else if (isForSubgroup()) {
1098 layout = getEffectiveLaneLayoutAsInt();
1099 subShape = getEffectiveLaneDataAsInt();
1100 } else {
1101 return failure();
1102 }
1103
1104 if (subShape.empty())
1105 return failure();
1106
1107 // delinearize Ids
1108 auto maybeIds = delinearizeId(builder, loc, linearId);
1109 if (failed(maybeIds))
1110 return failure();
1111
1112 // The effective sgIds for offsets computing correspond
1113 // to the dims that are not sliced.
1114 ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
1115 SmallVector<Value> canonicalIds =
1116 XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
1117
1118 return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
1119}
1120
1121/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
1122/// compute multi-dimensional offsets for a given linear ID when distributed by
1123/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
1124/// only the non-sliced dimensions for coordinate computation.
1125SmallVector<SmallVector<int64_t>>
1126SliceAttr::computeStaticDistributedCoords(int64_t linearId,
1127 ArrayRef<int64_t> shape) {
1128 assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
1129
1130 SmallVector<int64_t> layout;
1131 SmallVector<int64_t> subShape;
1132 SmallVector<int64_t> instData;
1133 if (isForWorkgroup()) {
1134 layout = getEffectiveSgLayoutAsInt();
1135 subShape = getEffectiveSgDataAsInt();
1136 } else if (isForSubgroup()) {
1137 instData = getEffectiveInstDataAsInt();
1138 layout = getEffectiveLaneLayoutAsInt();
1139 subShape = getEffectiveLaneDataAsInt();
1140 }
1141 if (!instData.empty()) {
1142 linearId = 0;
1143 subShape = instData;
1144 }
1145
1146 assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
1147
1148 // Delinearize the ID using the parent layout (same as the IR version).
1149 SliceAttr flattened = flatten();
1150 auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
1151 SmallVector<int64_t> parentLayoutVec;
1152 if (parent.isForWorkgroup())
1153 parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
1154 else
1155 parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
1156
1157 SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
1158 SmallVector<int64_t> allIds(parentLayoutVec.size());
1159 int64_t remaining = linearId;
1160 for (size_t i = 0; i < order.size(); ++i) {
1161 int64_t dimIdx = order[i];
1162 allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
1163 if (i < order.size() - 1)
1164 remaining = remaining / parentLayoutVec[dimIdx];
1165 }
1166
1167 // The effective IDs for coordinate computation correspond
1168 // to the dims that are not sliced.
1169 ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
1170 SmallVector<int64_t> canonicalIds =
1171 XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
1172
1173 return genStaticCoordinates(canonicalIds, layout, subShape, shape);
1174}
1175
1176bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
1177 auto flattenedThis = flatten();
1178 // If other is a LayoutAttr, just compare directly with parent of
1179 // flattenedThis.
1180 if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
1181 return flattenedThis.getParent() == otherLayout;
1182 // If other is a SliceAttr, flatten it first before comparing.
1183 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
1184 // Both must have common parent LayoutAttr.
1185 if (flattenedThis.getParent() != flattenedOther.getParent())
1186 return false;
1187 // otherFlattened's sliced dims must be a subset of flattenedThis's sliced
1188 // dims.
1189 llvm::SmallDenseSet<int64_t> thisDims(
1190 flattenedThis.getDims().asArrayRef().begin(),
1191 flattenedThis.getDims().asArrayRef().end());
1192 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
1193 [&](int64_t dim) { return thisDims.contains(dim); });
1194}
1195
1196bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
1197 if (dyn_cast<xegpu::LayoutAttr>(other))
1198 return false;
1199
1200 auto flattenedThis = flatten();
1201 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
1202
1203 return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
1204 (flattenedThis.getDims() == flattenedOther.getDims()));
1205}
1206
1207bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
1208 SmallVector<int64_t> shape,
1209 xegpu::LayoutKind level) {
1210 if (!other)
1211 return false;
1212 if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
1213 // short cut when order is the same, no need to compute coords and compare
1214 if (level == xegpu::LayoutKind::Subgroup)
1215 if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
1216 getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
1217 return true;
1218 if (level == xegpu::LayoutKind::Lane)
1219 if (getEffectiveLaneLayoutAsInt() ==
1220 other.getEffectiveLaneLayoutAsInt() &&
1221 getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
1222 return true;
1223 }
1224
1225 auto compareCoordsForAllIds = [&](int64_t size) {
1226 return compareDistributedCoords(*this, other, shape, level, size);
1227 };
1228
1229 auto flattenedThis = flatten();
1230 auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
1231 if (level == xegpu::LayoutKind::Subgroup) {
1232 int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
1233 return compareCoordsForAllIds(wgSize);
1234 }
1235 if (level == xegpu::LayoutKind::InstData) {
1236 return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
1237 }
1238 if (level == xegpu::LayoutKind::Lane) {
1239 int64_t subgroupSize = computeProduct(parent.getEffectiveLaneLayoutAsInt());
1240 return compareCoordsForAllIds(subgroupSize);
1241 }
1242 return true;
1243}
1244
1245xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
1246 if (sliceDimsToDrop.empty())
1247 return *this;
1248 SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
1249 for (auto dim : sliceDimsToDrop) {
1250 auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
1251 assert(foundIt != sliceDims.end() &&
1252 "Expected to find the specified reduction dim in slice dims");
1253 sliceDims.erase(foundIt);
1254 }
1255
1256 auto sliceWithoutDims = xegpu::SliceAttr::get(
1257 this->getContext(), getParent(),
1258 DenseI64ArrayAttr::get(this->getContext(), sliceDims));
1259
1260 return sliceWithoutDims;
1261}
1262
1263// Helper function to adjust dimensions from sliced space to parent space
1264// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
1265// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
1266// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
1267// it maps to dim [2] in parent space.
1268static SmallVector<int64_t>
1270 ArrayRef<int64_t> sliceDims) {
1271 // Rather than recovering the exact parent rank, we compute a safe upper
1272 // bound so that dimsToMap can be adjusted safely. This upper bound is
1273 // defined as max(dimsToMap, sliceDims) + 1 + sliceDims.size().
1274 int64_t maxDim = -1;
1275 maxDim =
1276 std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
1277 maxDim =
1278 std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
1279 int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
1280
1281 // get remaining dims in parent space after applying slicing with parent's
1282 // slice Dims
1283 llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
1284 sliceDims.end());
1285 SmallVector<int64_t> remainingDims;
1286 for (int64_t i = 0; i < parentSpaceRank; ++i) {
1287 if (!slicedDimsSet.contains(i))
1288 remainingDims.push_back(i);
1289 }
1290
1291 // Map unit dims from sliced space to parent space
1292 SmallVector<int64_t> adjustUnitDims;
1293 for (auto dim : dimsToMap) {
1294 int64_t mappedDim = remainingDims[dim];
1295 adjustUnitDims.push_back(mappedDim);
1296 }
1297
1298 return adjustUnitDims;
1299}
1300
1301// set the layout for unit dims: sg_data, inst_data and lane_data to 1
1302DistributeLayoutAttr
1303SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
1304 DistributeLayoutAttr parentLayout = getParent();
1305
1306 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1307
1308 SmallVector<int64_t> adjustUnitDims =
1309 mapSlicedDimsToParentSpace(unitDims, sliceDims);
1310
1311 return SliceAttr::get(getContext(),
1312 parentLayout.setUnitDimData(adjustUnitDims), getDims());
1313}
1314
1315// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
1316DistributeLayoutAttr
1317SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
1318 DistributeLayoutAttr parentLayout = getParent();
1319
1320 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1321
1322 SmallVector<int64_t> adjustUnitDims =
1323 mapSlicedDimsToParentSpace(unitDims, sliceDims);
1324
1325 return SliceAttr::get(
1326 getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
1327}
1328
1329// Derive a new layout with sg_data, inst_data and lane_data set to the
1330// specified values for the given dimension
1331DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
1332 int64_t instData, int64_t laneData) {
1333 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1334 auto parent = getParent();
1335
1336 SmallVector<int64_t> dimSet;
1337 dimSet.push_back(dim);
1338 SmallVector<int64_t> adjustDims =
1339 mapSlicedDimsToParentSpace(dimSet, sliceDims);
1340 return SliceAttr::get(
1341 getContext(),
1342 parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
1343}
1344
1345// Derive a new layout by removing dimensions. `dimGroup` specifies a group of
1346// dimensions to be removed in the derived layout.
1347//
1348// Example: drop the 2nd dimension from a rank-3 sliced view.
1349//
1350// Suppose:
1351// xegpu.layout = slice<layout<[V0, V1, V2, V3, V4]>, [1, 3]>
1352//
1353// The slice removes parent dims [1, 3], so the sliced-space dims map to
1354// parent dims [V0, V2, V4].
1355//
1356// If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
1357// parent dim 2, result in parent layout [V0, V1, V3, V4] after dropping.
1358// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1,
1359// 2].
1360//
1361// Result:
1362// xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
1363DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
1364 // Map the sliced dims from parent space to collapsed space
1365 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1366 SmallVector<int64_t> dimsInParentSpace =
1367 mapSlicedDimsToParentSpace(dimGroup, sliceDims);
1368
1369 auto droppedParent = getParent().dropDims(dimsInParentSpace);
1370
1371 // Adjust the sliced dims after dropping dims in parent space. For example, if
1372 // we drop dim 2 in parent space, the dims after dim 2 will all be shifted by
1373 // 1, so sliced dim 3 will be adjusted to 2.
1374 SmallVector<int64_t> newSliceDims;
1375 for (int64_t d : sliceDims) {
1376 int64_t offset =
1377 llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
1378 newSliceDims.push_back(d - offset);
1379 }
1380
1381 return SliceAttr::get(getContext(), droppedParent,
1382 DenseI64ArrayAttr::get(getContext(), newSliceDims));
1383}
1384
1385// Derive a new layout by collapsing dimensions.
1386// `dimGroup` specifies a group of adjacent dimensions
1387// that are collapsed into a single dimension in the derived layout.
1388DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
1389
1390 // Map the sliced dims from parent space to collapsed space
1391 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1392 assert("expect sliceDims not being collapsed" &&
1393 llvm::none_of(dimGroup, [&](int64_t dim) {
1394 return llvm::is_contained(sliceDims, dim);
1395 }));
1396 SmallVector<int64_t> dimsInParentSpace =
1397 mapSlicedDimsToParentSpace(dimGroup, sliceDims);
1398
1399 auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
1400 return SliceAttr::get(getContext(), collapsedParent,
1401 DenseI64ArrayAttr::get(getContext(), sliceDims));
1402}
1403
1404// Derive a new layout by expanding a single sliced-space dim into multiple
1405// adjacent dims. The dim is mapped to parent space, the parent layout is
1406// expanded there, and the slice dims that lie past the expanded position
1407// are shifted up by `targetShape.size() - 1`.
1408DistributeLayoutAttr SliceAttr::expandDim(int64_t dim,
1409 ArrayRef<int64_t> targetShape) {
1410 // `dim` is in slice space; map it to parent space (parent dims listed in
1411 // `sliceDims` are removed by the slice, so the mapping always lands on a
1412 // non-sliced parent dim).
1413 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1414 SmallVector<int64_t> dimSet = {dim};
1415 SmallVector<int64_t> dimsInParentSpace =
1416 mapSlicedDimsToParentSpace(dimSet, sliceDims);
1417 int64_t parentDim = dimsInParentSpace[0];
1418
1419 auto expandedParent = getParent().expandDim(parentDim, targetShape);
1420
1421 int64_t shift = static_cast<int64_t>(targetShape.size()) - 1;
1422 SmallVector<int64_t> newSliceDims;
1423 newSliceDims.reserve(sliceDims.size());
1424 for (int64_t s : sliceDims)
1425 newSliceDims.push_back(s > parentDim ? s + shift : s);
1426
1427 return SliceAttr::get(getContext(), expandedParent,
1428 DenseI64ArrayAttr::get(getContext(), newSliceDims));
1429}
1430
1432 ArrayRef<int64_t> permutation) {
1433 SmallVector<int64_t> sortedSliceDims = llvm::to_vector(sliceDims);
1434 llvm::sort(sortedSliceDims);
1435
1436 for (size_t i = 1; i < sortedSliceDims.size(); ++i) {
1437 assert((sortedSliceDims[i] == sortedSliceDims[i - 1] + 1) &&
1438 "slice dims non consecutive, cannot be transposed");
1439 }
1440
1441 SmallVector<int64_t> permForParent;
1442 if (sortedSliceDims.front() == 0) {
1443 // Example: sliceDims.size() = 2, permutation= {1, 0}
1444 // result: {3, 2, 1, 0}.
1445 for (int64_t dim : permutation)
1446 permForParent.push_back(dim + sortedSliceDims.size());
1447 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1448 permForParent.push_back(i);
1449 } else {
1450 // Example: sliceDims.size() = 2, permutation = {0, 1}
1451 // result: {3, 2, 0, 1}.
1452 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1453 permForParent.push_back(i + permutation.size());
1454 for (int64_t dim : permutation)
1455 permForParent.push_back(dim);
1456 }
1457 return permForParent;
1458}
1459
1460// Derive a new layout by transpose the layout using `permutation`.
1461DistributeLayoutAttr SliceAttr::transposeDims(ArrayRef<int64_t> permutation) {
1462 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1463 DistributeLayoutAttr parent = getParent();
1464 SmallVector<int64_t> permForParent =
1465 getPermForParentLayout(sliceDims, permutation);
1466 auto transposedParent = parent.transposeDims(permForParent);
1467 return SliceAttr::get(getContext(), transposedParent,
1468 DenseI64ArrayAttr::get(getContext(), sliceDims));
1469}
1470
1471/// Check if this layout is a transpose of another layout.
1472bool SliceAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
1473 ArrayRef<int64_t> perm,
1474 const xegpu::LayoutKind kind) {
1475 // other must be a SliceAttr with the same slice dims.
1476 auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
1477 if (!otherSlice || getDims() != otherSlice.getDims())
1478 return false;
1479 // check whether the parent layout is transpose of each other.
1480 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1481 DistributeLayoutAttr parent = getParent();
1482 SmallVector<int64_t> permForParent = getPermForParentLayout(sliceDims, perm);
1483 auto otherParent = otherSlice.getParent();
1484 return parent.isTransposeOf(otherParent, permForParent, kind);
1485}
1486
1487//===----------------------------------------------------------------------===//
1488// XeGPU_RangeAttr
1489//===----------------------------------------------------------------------===//
1490
1491LogicalResult
1492RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1493 IntegerAttr startOfRange, IntegerAttr endOfRange) {
1494 if (startOfRange.getInt() >= endOfRange.getInt())
1495 return emitError() << "'end' : " << endOfRange.getInt()
1496 << " must be greater than 'start' : "
1497 << startOfRange.getInt();
1498
1499 return success();
1500}
1501
1502//===----------------------------------------------------------------------===//
1503// XeGPU_TensorDescType
1504//===----------------------------------------------------------------------===//
1505
1506mlir::Type TensorDescType::parse(AsmParser &parser) {
1507 llvm::SmallVector<int64_t> shape;
1508 mlir::Type elementType;
1509 mlir::FailureOr<mlir::Attribute> encoding;
1510 mlir::FailureOr<mlir::Attribute> layout;
1511
1512 // Parse literal '<'
1513 if (parser.parseLess())
1514 return {};
1515
1516 auto shapeLoc = parser.getCurrentLocation();
1517 if (mlir::failed(parser.parseDimensionList(shape))) {
1518 parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
1519 return {};
1520 }
1521
1522 auto elemTypeLoc = parser.getCurrentLocation();
1523 if (mlir::failed(parser.parseType(elementType))) {
1524 parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
1525 return {};
1526 }
1527
1528 // parse optional attributes
1529 while (mlir::succeeded(parser.parseOptionalComma())) {
1530 mlir::Attribute attr;
1531 ParseResult res = parser.parseAttribute(attr);
1532 if (mlir::succeeded(res)) {
1533 if (mlir::isa<DistributeLayoutAttr>(attr)) {
1534 layout = attr;
1535 continue;
1536 }
1537 if (mlir::isa<BlockTensorDescAttr>(attr)) {
1538 encoding = attr;
1539 continue;
1540 }
1541 }
1542 return {};
1543 }
1544
1545 // Parse literal '>'
1546 if (parser.parseGreater())
1547 return {};
1548
1549 MLIRContext *ctxt = parser.getContext();
1550 return TensorDescType::getChecked(
1551 [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
1552 elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
1553 layout.value_or(mlir::Attribute()));
1554}
1555
1556void TensorDescType::print(AsmPrinter &printer) const {
1557 printer << "<";
1558
1559 auto shape = getShape();
1560 for (int64_t dim : shape) {
1561 if (mlir::ShapedType::isDynamic(dim))
1562 printer << '?';
1563 else
1564 printer << dim;
1565 printer << 'x';
1566 }
1567
1568 printer << getElementType();
1569
1570 auto encoding = getEncoding();
1571 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1572 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
1573 printer << ", " << encoding;
1574
1575 if (auto layout = getLayout())
1576 printer << ", " << layout;
1577
1578 printer << ">";
1579}
1580
1581TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
1582 mlir::Type elementType, int array_length,
1583 bool boundary_check,
1584 MemorySpace memory_space,
1585 mlir::Attribute layout) {
1586 auto *context = elementType.getContext();
1587 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
1588 boundary_check);
1589 return Base::get(context, shape, elementType, attr, layout);
1590}
1591
1592LogicalResult
1593TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
1594 llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
1595 mlir::Attribute encoding, mlir::Attribute layout) {
1596 size_t rank = shape.size();
1597
1598 if (rank == 0)
1599 return emitError() << "expected non-zero rank tensor";
1600
1601 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1602 if (blockAttr) {
1603 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
1604 if (rank > 1 && memorySpaceAttr &&
1605 memorySpaceAttr.getValue() == MemorySpace::SLM)
1606 return emitError() << "SLM is only supported for 1D block tensor";
1607 }
1608
1609 if (!elementType.isIntOrFloat())
1610 return emitError() << "unsupported element type " << elementType
1611 << ": expected integer or float";
1612
1613 if (auto layoutAttr =
1614 mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
1615 if (rank != (size_t)layoutAttr.getRank())
1616 return emitError() << "expected layout rank to match tensor rank";
1617
1618 if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
1619 std::string shapeStr;
1620 llvm::raw_string_ostream stream(shapeStr);
1621 llvm::interleaveComma(shape, stream);
1622 return emitError() << "cannot distribute [" << shapeStr << "] using "
1623 << layoutAttr;
1624 }
1625 }
1626
1627 return success();
1628}
1629
1630//===----------------------------------------------------------------------===//
1631// XeGPU_MemDescType
1632//===----------------------------------------------------------------------===//
1633mlir::Type MemDescType::parse(AsmParser &parser) {
1634 llvm::SmallVector<int64_t> shape;
1635 mlir::Type elementType;
1636 mlir::FailureOr<MemLayoutAttr> layout;
1637
1638 // Parse literal '<'
1639 if (parser.parseLess())
1640 return {};
1641
1642 auto shapeLoc = parser.getCurrentLocation();
1643 if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
1644 parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
1645 return {};
1646 }
1647
1648 auto elemTypeLoc = parser.getCurrentLocation();
1649 if (mlir::failed(parser.parseType(elementType))) {
1650 parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
1651 return {};
1652 }
1653
1654 // parse optional attributes
1655 if (mlir::succeeded(parser.parseOptionalComma())) {
1656 MemLayoutAttr attr;
1657 ParseResult res = parser.parseAttribute(attr);
1658 if (mlir::failed(res))
1659 return {};
1660 layout = attr;
1661 }
1662
1663 // Parse literal '>'
1664 if (parser.parseGreater())
1665 return {};
1666
1667 MLIRContext *ctxt = parser.getContext();
1668 return MemDescType::getChecked(
1669 [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
1670 elementType, layout.value_or(MemLayoutAttr()));
1671}
1672
1673void MemDescType::print(AsmPrinter &printer) const {
1674 printer << "<";
1675
1676 printer.printDimensionList(getShape());
1677 printer << 'x';
1678 printer << getElementType();
1679
1680 if (auto layout = getMemLayout())
1681 printer << ", " << layout;
1682
1683 printer << ">";
1684}
1685
1686//===----------------------------------------------------------------------===//
1687// XeGPU_MemDescType
1688//===----------------------------------------------------------------------===//
1689
1690Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
1691
1692 auto *context = parser.getContext();
1693 llvm::SMLoc loc = parser.getCurrentLocation();
1694
1695 llvm::SmallDenseSet<StringRef> seenKeys;
1696 SmallVector<NamedAttribute> attributes;
1697
1698 auto parseElt = [&]() -> ParseResult {
1699 StringRef nameId;
1700 if (failed(parser.parseKeyword(&nameId)))
1701 return parser.emitError(loc, "expected valid attribute name");
1702
1703 if (!seenKeys.insert(nameId).second)
1704 return parser.emitError(loc, "duplicate key '")
1705 << nameId << " in mem layout attribute";
1706
1707 if (failed(parser.parseEqual()))
1708 return failure();
1709
1710 Attribute attr;
1711 if (failed(parser.parseAttribute(attr)))
1712 return failure();
1713 attributes.emplace_back(nameId, attr);
1714 return success();
1715 };
1716
1717 // Parse literal '<'
1718 if (parser.parseLess())
1719 return {};
1720
1721 if (failed(parser.parseCommaSeparatedList(parseElt)))
1722 return {};
1723
1724 // Parse literal '>'
1725 if (parser.parseGreater())
1726 return {};
1727
1728 return parser.getChecked<MemLayoutAttr>(
1729 loc, context, DictionaryAttr::get(context, attributes));
1730}
1731
1732void MemLayoutAttr::print(AsmPrinter &printer) const {
1733 printer << "<";
1734 ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
1735 for (size_t i = 0; i < attrs.size(); i++) {
1736 printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
1737 if (i < attrs.size() - 1)
1738 printer << ", ";
1739 }
1740 printer << ">";
1741}
1742// a helper utility to perform binary operation on OpFoldResult.
1743// If both a and b are attributes, it will simply return the result.
1744// Otherwise, the corresponding arith op will be generated, and an
1745// contant op will be created if one of them is an attribute.
1746template <typename ArithOp>
1748 OpBuilder &builder) {
1749 auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
1750 auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
1751 return ArithOp::create(builder, loc, aVal, bVal).getResult();
1752}
1753
// Shorthands over genBinOp. Every expansion refers to the names `builder` and
// `loc`, so both must be in scope wherever one of these macros is used.

// Signed division of an OpFoldResult by an int64_t.
#define div(lhs, rhs)                                                          \
  genBinOp<arith::DivSIOp>(lhs, builder.getIndexAttr(rhs), loc, builder)

// Signed remainder of an OpFoldResult divided by an int64_t.
#define rem(lhs, rhs)                                                          \
  genBinOp<arith::RemSIOp>(lhs, builder.getIndexAttr(rhs), loc, builder)

// Product of an OpFoldResult and an int64_t.
#define mul(lhs, rhs)                                                          \
  genBinOp<arith::MulIOp>(lhs, builder.getIndexAttr(rhs), loc, builder)

// Sum of two OpFoldResult.
#define add(lhs, rhs) genBinOp<arith::AddIOp>(lhs, rhs, loc, builder)
1768
1769// block the given offsets according to the block shape
1770// say the original offset is [y, x], and the block shape is [By, Bx],
1771// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1773 ArrayRef<OpFoldResult> offsets,
1774 ArrayRef<int64_t> blockShape) {
1775
1776 assert(offsets.size() == blockShape.size() &&
1777 "offsets and blockShape must have the same size");
1778 SmallVector<OpFoldResult> blockedOffsets;
1779 SmallVector<OpFoldResult> divs, rems;
1780
1781 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1782 divs.push_back(div(offset, block));
1783 rems.push_back(rem(offset, block));
1784 }
1785 blockedOffsets.append(divs.begin(), divs.end());
1786 blockedOffsets.append(rems.begin(), rems.end());
1787
1788 return blockedOffsets;
1789}
1790
1791// Get strides as vector of integer for MemDesc.
1792SmallVector<int64_t> MemDescType::getStrideShape() {
1793
1794 SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
1795
1796 ArrayAttr strideAttr = getStrideAttr();
1797 SmallVector<int64_t> strides;
1798 for (Attribute attr : strideAttr.getValue()) {
1799 strides.push_back(cast<IntegerAttr>(attr).getInt());
1800 }
1801
1802 SmallVector<int64_t> innerBlkShape = getBlockShape();
1803
1804 // get perm from FCD to LCD
1805 // perm[i] = the dim with i-th smallest stride
1806 SmallVector<int, 4> perm =
1807 llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
1808 llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
1809
1810 assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
1811
1812 SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
1813 innerBlkStride[perm[0]] = 1;
1814 for (size_t i = 1; i < perm.size(); ++i)
1815 innerBlkStride[perm[i]] =
1816 innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
1817
1818 // compute the original matrix shape using the stride info
1819 // and compute the number of blocks in each dimension
1820 // The shape of highest dim can't be derived from stride info,
1821 // but doesn't impact the stride computation for blocked layout.
1822 SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
1823 SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
1824 for (size_t i = 0; i < perm.size() - 1; ++i) {
1825 matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
1826 BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
1827 }
1828
1829 int64_t innerBlkSize = 1;
1830 for (auto s : innerBlkShape)
1831 innerBlkSize *= s;
1832
1833 SmallVector<int64_t> outerBlkStride(matrixShape.size());
1834 outerBlkStride[perm[0]] = innerBlkSize;
1835 for (size_t i = 0; i < perm.size() - 1; ++i) {
1836 outerBlkStride[perm[i + 1]] =
1837 outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
1838 }
1839
1840 // combine the inner and outer strides
1841 SmallVector<int64_t> blockedStrides;
1842 blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
1843 blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
1844
1845 return blockedStrides;
1846}
1847
1848// Calculate the linear offset using the blocked offsets and stride
1849Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
1850 ArrayRef<OpFoldResult> offsets) {
1851
1852 SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
1853 SmallVector<int64_t> blockShape = getBlockShape();
1854 SmallVector<int64_t> strides = getStrideShape();
1855 SmallVector<OpFoldResult> blockedOffsets;
1856
1857 // blockshape equal to matrixshape means no blocking
1858 if (llvm::equal(blockShape, matrixShape)) {
1859 // remove the outer dims from strides
1860 strides.erase(strides.begin(), strides.begin() + matrixShape.size());
1861 } else {
1862 assert(offsets.size() == blockShape.size() &&
1863 "offsets and blockShape must have the same size");
1864 // say the original offset is [y, x], and the block shape is [By, Bx],
1865 // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1866
1867 SmallVector<OpFoldResult> divs, rems;
1868
1869 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1870 divs.push_back(div(offset, block));
1871 rems.push_back(rem(offset, block));
1872 }
1873 blockedOffsets.append(divs.begin(), divs.end());
1874 blockedOffsets.append(rems.begin(), rems.end());
1875 offsets = blockedOffsets;
1876 }
1877
1878 // Start with initial value as matrix descriptor's base offset.
1879 Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
1880 for (size_t i = 0; i < offsets.size(); ++i) {
1881 OpFoldResult mulResult = mul(offsets[i], strides[i]);
1882 Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
1883 linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
1884 }
1885
1886 return linearOffset;
1887}
1888
1889} // namespace xegpu
1890} // namespace mlir
1891
1892#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
1893#define GET_ATTRDEF_CLASSES
1894#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
1895#define GET_TYPEDEF_CLASSES
1896#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
return success()
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define div(a, b)
#define rem(a, b)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
static SmallVector< SmallVector< int64_t > > genStaticCoordinates(llvm::ArrayRef< int64_t > canonicalIds, llvm::ArrayRef< int64_t > layout, llvm::ArrayRef< int64_t > subShape, llvm::ArrayRef< int64_t > shape)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
static SmallVector< int64_t > mapSlicedDimsToParentSpace(const SmallVector< int64_t > &dimsToMap, ArrayRef< int64_t > sliceDims)
SmallVector< OpFoldResult > getBlockedOffsets(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > offsets, ArrayRef< int64_t > blockShape)
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder)
static bool compareDistributedCoords(xegpu::DistributeLayoutAttr self, const xegpu::DistributeLayoutAttr &other, ArrayRef< int64_t > shape, xegpu::LayoutKind level, int64_t size)
Returns true if self and other distribute shape identically at level: every id in [0,...
static SmallVector< SmallVector< Value > > genCoordinates(OpBuilder &builder, Location loc, SmallVector< Value > delinearizedId, ArrayRef< int64_t > subShapesLayout, ArrayRef< int64_t > subShape, ArrayRef< int64_t > srcShape)
SmallVector< int64_t > getPermForParentLayout(ArrayRef< int64_t > sliceDims, ArrayRef< int64_t > permutation)
static SmallVector< SmallVector< int64_t > > expandBlockCoords(ArrayRef< SmallVector< int64_t > > blockStarts, ArrayRef< int64_t > subShape)
Expands per-distribution-unit block-start coordinates into the full list of element coordinates each ...
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.