MLIR 23.0.0git
XeGPUDialect.cpp
Go to the documentation of this file.
1//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/Builders.h"
16#include "llvm/ADT/SmallVectorExtras.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/Debug.h"
19
20using std::optional;
21
22namespace mlir {
23namespace xegpu {
24
25void XeGPUDialect::initialize() {
26 addTypes<
27#define GET_TYPEDEF_LIST
28#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
29 >();
30 addOperations<
31#define GET_OP_LIST
32#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
33 >();
34 addAttributes<
35#define GET_ATTRDEF_LIST
36#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
37 >();
38}
39#define GET_OP_INTERFACE_CLASSES
40#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
41
42// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
43// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
44// within each distribution unit.
45// Example:
46// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a
47// distribution unit of shape 64x64, we have 2x4 such distribution units.
48// `delinearizedId` is used to identify a 16x32 of a subgroup in each
49// distribution unit.
52 SmallVector<Value> delinearizedId,
53 ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
54 ArrayRef<int64_t> srcShape) {
56
57 // A distribution unit must be less than or equal to `srcShape`
58 SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
59 llvm::zip_equal(srcShape,
60 computeElementwiseMul(subShapesLayout, subShape)),
61 [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
62
63 // Get the offset of `subShape` within a distribution unit.
64 SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
65 llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
66 return builder.createOrFold<arith::MulIOp>(
67 loc, std::get<0>(t),
68 builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
69 });
70
71 // For each dist unit
72 for (SmallVector<int64_t> unitOffs :
73 StaticTileOffsetRange(srcShape, distUnitShape)) {
74 // Get dist unit offset within `srcShape`.
76 llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
77 return arith::ConstantIndexOp::create(builder, loc, d);
78 });
79 // Calculate `subShape` offset within `srcShape`.
81 llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
82 [&](const auto &t) -> Value {
83 return builder.createOrFold<arith::AddIOp>(
84 loc, std::get<0>(t), std::get<1>(t));
85 });
86 // Do not go beyond `srcShape` bounds.
87 SmallVector<Value> mods = llvm::map_to_vector(
88 llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
89 return builder.createOrFold<arith::RemUIOp>(
90 loc, std::get<0>(t),
91 arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
92 });
93
94 coordinates.push_back(mods);
95 }
96 return coordinates;
97}
98
102 // Compute distribution unit shape (clamped to srcShape).
103 SmallVector<int64_t> distUnitShape(shape.size());
104 for (size_t i = 0; i < shape.size(); ++i)
105 distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
106
107 // Compute local offset of this ID within a distribution unit.
108 SmallVector<int64_t> localOffset(shape.size());
109 for (size_t i = 0; i < shape.size(); ++i)
110 localOffset[i] = canonicalIds[i] * subShape[i];
111
112 // Enumerate all distribution units and compute coordinates.
114 for (SmallVector<int64_t> unitOffs :
115 StaticTileOffsetRange(shape, distUnitShape)) {
116 SmallVector<int64_t> coord(shape.size());
117 for (size_t i = 0; i < shape.size(); ++i)
118 coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
119 coordinates.push_back(coord);
120 }
121 return coordinates;
122}
123
124// Checks if the given shape can be evenly distributed based on the layout
125// and data factors provided by the LayoutAttr.
126bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
127 xegpu::DistributeLayoutAttr attr) {
128 assert(attr && "Layout attribute is missing.");
129
130 // Checks whether the given shape can be evenly distributed using the
131 // specified layout and data attributes. If successful, it returns the work
132 // size for each compute unit; otherwise, it returns `std::nullopt`. The work
133 // size per compute unit is calculated as follows:
134 // - If `data` is null: newShape[i] = shape[i] / layout[i]
135 // - If `data` is not null: newShape[i] = data[i]
136 // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
137 // smaller than `layout[i] * data[i]`, allowing multiple compute units to
138 // share the data.
139 auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
142 bool rr = true) -> optional<SmallVector<int64_t>> {
144 if (layout.size()) {
145 if (layout.size() != shape.size())
146 return std::nullopt;
147 auto ratio = computeShapeRatio(shape, layout);
148 if (ratio.has_value()) {
149 newShape = ratio.value();
150 } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
151 return std::nullopt;
152 }
153 // Round-robin case: continue with original newShape
154 }
155
156 if (data.size()) {
157 if (data.size() != shape.size())
158 return std::nullopt;
159 auto ratio = computeShapeRatio(newShape, data);
160 if (!ratio.has_value() && rr)
161 ratio = computeShapeRatio(data, newShape);
162 if (!ratio.has_value())
163 return std::nullopt;
164
165 // if data is not null, we always return it for next phase.
166 newShape = data;
167 }
168 return newShape;
169 };
170
171 // check the sgLayout and sgData
172 auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
173 attr.getEffectiveSgDataAsInt());
174 if (!maybeSgShape)
175 return false;
176 auto sgShape = maybeSgShape.value();
177
178 // check InstData, it neither have layout nor need round-robin
179 auto maybeInstShape =
180 tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
181 if (!maybeInstShape)
182 return false;
183 auto instShape = maybeInstShape.value();
184
185 // check LaneLayout and LaneData
186 auto maybeLaneShape =
187 tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
188 attr.getEffectiveLaneDataAsInt());
189 return maybeLaneShape.has_value();
190}
191
192//===----------------------------------------------------------------------===//
193// XeGPU_BlockTensorDescAttr
194//===----------------------------------------------------------------------===//
195BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
196 xegpu::MemorySpace memory_space,
197 int array_length,
198 bool boundary_check) {
199 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
200 auto lengthAttr =
201 IntegerAttr::get(IntegerType::get(context, 64), array_length);
202 auto boundaryAttr = BoolAttr::get(context, boundary_check);
203 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
204}
205
206bool BlockTensorDescAttr::hasDefaultsOnly() {
207 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
208 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
209}
210
211//===----------------------------------------------------------------------===//
212// XeGPU_ScatterTensorDescAttr
213//===----------------------------------------------------------------------===//
214ScatterTensorDescAttr
215ScatterTensorDescAttr::get(mlir::MLIRContext *context,
216 xegpu::MemorySpace memory_space, int chunk_size) {
217 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
218 auto chunkSizeAttr =
219 IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
220 return Base::get(context, scopeAttr, chunkSizeAttr);
221}
222
223LogicalResult ScatterTensorDescAttr::verify(
224 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
225 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
226 int64_t chunkSize = chunk_size.getInt();
227 if (chunkSize <= 0)
228 return emitError() << "invalid chunk size";
229
230 return success();
231}
232
233//===----------------------------------------------------------------------===//
234// XeGPU_LayoutAttr
235//===----------------------------------------------------------------------===//
236LogicalResult
237LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
238 DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
239 DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
240 DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {
241
242 // Special case for store_matrix
243 if (!sg_layout && !inst_data && !lane_layout)
244 return success();
245
246 // generate code to check sg_laout, inst_data and lane_layout having the same
247 // rank if they are not null.
248
249 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
250 return emitError()
251 << "expected sg_layout and inst_data to have the same rank";
252 }
253
254 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
255 return emitError()
256 << "expected sg_layout and lane_layout to have the same rank";
257 }
258
259 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
260 return emitError() << "expected inst_data and lane_layout to have the same "
261 "rank, got inst_data "
262 << inst_data.size() << ", lane_layout "
263 << lane_layout.size();
264 }
265
266 if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
267 return emitError() << "sg_layout and sg_data must be used together";
268 if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
269 return emitError()
270 << "expected sg_data and sg_layout to have the same rank";
271
272 if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
273 return emitError() << "lane_layout and lane_data must be used together";
274 if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
275 return emitError()
276 << "expected lane_data and lane_layout to have the same rank";
277
278 if (order) {
279 if (!sg_layout && !lane_layout)
280 return emitError()
281 << "expected sg_layout/lane_layout being used with order";
282
283 if (sg_layout && order.size() != sg_layout.size())
284 return emitError()
285 << "expected order and sg_layout to have the same rank";
286
287 if (lane_layout && order.size() != lane_layout.size())
288 return emitError()
289 << "expected order and lane_layout to have the same rank";
290 }
291
292 return success();
293}
294
295FailureOr<SmallVector<Value>>
296LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
297
298 SmallVector<int64_t> sgLayoutInt;
299 if (isForWorkgroup()) {
300 sgLayoutInt = getEffectiveSgLayoutAsInt();
301 } else if (isForSubgroup()) {
302 sgLayoutInt = getEffectiveLaneLayoutAsInt();
303 } else {
304 return failure();
305 }
306
307 DenseI32ArrayAttr orderAttr = getOrder();
308
309 // Handle order attribute
310 SmallVector<int64_t> order;
311 if (orderAttr && !orderAttr.empty()) {
312 order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
313 return static_cast<int64_t>(idx);
314 });
315 } else {
316 // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
317 order = llvm::to_vector(
318 llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
319 }
320
321 if (order.size() != sgLayoutInt.size()) {
322 return failure();
323 }
324
325 SmallVector<Value> result(sgLayoutInt.size());
326 Value remaining = linearId;
327
328 /// Process dimensions in the order they appear in the order array
329 /// The first dimension in order is the fastest-changing
330 ///
331 /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
332 ///
333 /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
334 /// result=[?,?,?]
335 ///
336 /// i=0 (process columns, dimIdx=2, dimSize=4):
337 /// result[2] = 22 % 4 = 2 (column coordinate)
338 /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
339 ///
340 /// i=1 (process rows, dimIdx=1, dimSize=4):
341 /// result[1] = 5 % 4 = 1 (row coordinate)
342 /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
343 ///
344 /// i=2 (process layers, dimIdx=0, dimSize=2):
345 /// result[0] = 1 % 2 = 1 (layer coordinate)
346 /// (no remaining update - last iteration)
347 ///
348 /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
349 for (size_t i = 0; i < order.size(); ++i) {
350 int64_t dimIdx = order[i];
351 int64_t dimSize = sgLayoutInt[dimIdx];
352
353 Value dimSizeVal =
354 builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
355
356 /// Extract the coordinate for this dimension using modulo operation
357 /// This gives us "how far within this dimension" we are
358 /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
359 /// this dimension)
360 result[dimIdx] =
361 builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
362
363 /// Update remaining for the next dimension by removing what we've already
364 /// processed. Division tells us "how many complete groups of this dimension
365 /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
366 /// completed 5 groups of 4) Skip this for the last iteration since there's
367 /// no next dimension to process
368 if (i < order.size() - 1) {
369 remaining =
370 builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
371 }
372 }
373 return result;
374}
375
376/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
377/// instructions for computing multi-dimensional offsets when distributed by
378/// LayoutAttr.
379FailureOr<SmallVector<SmallVector<Value>>>
380LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
381 Value linearId, ArrayRef<int64_t> shape) {
382 SmallVector<int64_t> layout;
383 SmallVector<int64_t> subShape;
384 if (isForWorkgroup()) {
385 layout = getEffectiveSgLayoutAsInt();
386 subShape = getEffectiveSgDataAsInt();
387 } else if (isForSubgroup()) {
388 layout = getEffectiveLaneLayoutAsInt();
389 subShape = getEffectiveLaneDataAsInt();
390 } else {
391 return failure();
392 }
393 assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
394 "distributed coordinates computation");
395
396 // delinearize Ids
397 auto maybeIds = delinearizeId(builder, loc, linearId);
398 if (failed(maybeIds))
399 return failure();
400 SmallVector<Value> ids = *maybeIds;
401
402 return genCoordinates(builder, loc, ids, layout, subShape, shape);
403}
404
405bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
406 if (dyn_cast<xegpu::SliceAttr>(other))
407 return false;
408
409 return *this == dyn_cast<xegpu::LayoutAttr>(other);
410}
411
412/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
413/// compute multi-dimensional offsets for a given linear ID when distributed by
414/// LayoutAttr.
415SmallVector<SmallVector<int64_t>>
416LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
417 ArrayRef<int64_t> shape) {
418 SmallVector<int64_t> layoutVec;
419 SmallVector<int64_t> subShape;
420 SmallVector<int64_t> instData;
421 if (isForWorkgroup()) {
422 layoutVec = getEffectiveSgLayoutAsInt();
423 subShape = getEffectiveSgDataAsInt();
424 } else if (isForSubgroup()) {
425 instData = getEffectiveInstDataAsInt();
426 layoutVec = getEffectiveLaneLayoutAsInt();
427 subShape = getEffectiveLaneDataAsInt();
428 }
429 if (!instData.empty()) {
430 linearId = 0;
431 subShape = instData;
432 }
433 assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
434
435 // Delinearize the linear ID using the order attribute.
436 SmallVector<int64_t> order = getEffectiveOrderAsInt();
437 SmallVector<int64_t> delinearizedId(layoutVec.size());
438 int64_t remaining = linearId;
439 for (size_t i = 0; i < order.size(); ++i) {
440 int64_t dimIdx = order[i];
441 delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
442 remaining = remaining / layoutVec[dimIdx];
443 }
444
445 return genStaticCoordinates(delinearizedId, layoutVec, subShape, shape);
446}
447
448// set the layout for unit dims: sg_data, inst_data and lane_data to 1
449DistributeLayoutAttr
450LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
451 auto sgDataOpt = getSgData();
452 auto instDataOpt = getInstData();
453 auto laneDataOpt = getLaneData();
454
455 SmallVector<int32_t> sgData;
456 SmallVector<int32_t> instData;
457 SmallVector<int32_t> laneData;
458
459 if (sgDataOpt)
460 sgData = llvm::to_vector(sgDataOpt.asArrayRef());
461
462 if (instDataOpt)
463 instData = llvm::to_vector(instDataOpt.asArrayRef());
464
465 if (laneDataOpt)
466 laneData = llvm::to_vector(laneDataOpt.asArrayRef());
467
468 for (auto dim : unitDims) {
469 if (dim < static_cast<int64_t>(sgData.size()))
470 sgData[dim] = 1;
471 if (dim < static_cast<int64_t>(instData.size()))
472 instData[dim] = 1;
473 if (dim < static_cast<int64_t>(laneData.size()))
474 laneData[dim] = 1;
475 }
476
477 return LayoutAttr::get(
478 getContext(), getSgLayout(),
479 sgData.empty() ? DenseI32ArrayAttr()
481 instData.empty() ? DenseI32ArrayAttr()
482 : DenseI32ArrayAttr::get(getContext(), instData),
483 getLaneLayout(),
484 laneData.empty() ? DenseI32ArrayAttr()
485 : DenseI32ArrayAttr::get(getContext(), laneData),
486 getOrder());
487}
488
489// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
490DistributeLayoutAttr
491LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
492 auto sgLayoutOpt = getSgLayout();
493 auto laneLayoutOpt = getLaneLayout();
494
495 SmallVector<int32_t> sgLayout;
496 SmallVector<int32_t> laneLayout;
497
498 if (sgLayoutOpt)
499 sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
500 if (laneLayoutOpt)
501 laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
502
503 for (auto dim : unitDims) {
504 if (dim < static_cast<int64_t>(sgLayout.size()))
505 sgLayout[dim] = 1;
506 if (dim < static_cast<int64_t>(laneLayout.size()))
507 laneLayout[dim] = 1;
508 }
509
510 return LayoutAttr::get(
511 getContext(),
512 sgLayout.empty() ? DenseI32ArrayAttr()
513 : DenseI32ArrayAttr::get(getContext(), sgLayout),
514 getSgData(), getInstData(),
515 laneLayout.empty() ? DenseI32ArrayAttr()
516 : DenseI32ArrayAttr::get(getContext(), laneLayout),
517 getLaneData(), getOrder());
518}
519
520// Derive a new layout with sg_data, inst_data and lane_data set to the
521// specified values for the given dimension
522DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
523 int64_t instData,
524 int64_t laneData) {
525
526 SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
527 SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
528 SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
529
530 if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
531 sgDataVec[dim] = sgData;
532 if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
533 instDataVec[dim] = instData;
534 if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
535 laneDataVec[dim] = laneData;
536
537 SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
538 SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
539 SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
540
541 return LayoutAttr::get(
542 getContext(), getSgLayout(),
543 sgDataVec.empty() ? DenseI32ArrayAttr()
544 : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
545 instDataVec.empty() ? DenseI32ArrayAttr()
546 : DenseI32ArrayAttr::get(getContext(), instDataVec32),
547 getLaneLayout(),
548 laneDataVec.empty() ? DenseI32ArrayAttr()
549 : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
550 getOrder());
551}
552
553// Derive a new layout by removing dimensions.
554// `dimGroup` specifies a group of dimensions to be removed in the derived
555// layout.
556DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
557
558 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
559 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
560 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
561 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
562 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
563 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
564
565 SmallVector<int64_t> sortedDimGroup = dimGroup;
566 llvm::sort(sortedDimGroup);
567
568 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
569 if (!sgLayout.empty()) {
570 sgLayout.erase(sgLayout.begin() + dimIdx);
571 sgData.erase(sgData.begin() + dimIdx);
572 }
573 if (!instData.empty())
574 instData.erase(instData.begin() + dimIdx);
575 if (!laneLayout.empty()) {
576 laneLayout.erase(laneLayout.begin() + dimIdx);
577 laneData.erase(laneData.begin() + dimIdx);
578 }
579 }
580
581 SmallVector<int64_t> newOrder;
582 for (int64_t d : origOrder) {
583 if (llvm::is_contained(dimGroup, d))
584 continue;
585 int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
586 newOrder.push_back(d - offset);
587 }
588 if (sgLayout.empty() && laneLayout.empty())
589 newOrder.clear();
590
591 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
592 if (v.empty())
593 return DenseI32ArrayAttr();
594 SmallVector<int32_t> v32(v.begin(), v.end());
595 return DenseI32ArrayAttr::get(getContext(), v32);
596 };
597 auto droppedLayout = xegpu::LayoutAttr::get(
598 getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
599 toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
600 return droppedLayout;
601}
602
603// Derive a new layout by collapsing dimensions.
604// `dimGroup` specifies a group of adjacent dimensions
605// that are collapsed into a single dimension in the derived layout.
606DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
607
608 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
609 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
610 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
611 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
612 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
613 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
614
615 SmallVector<int64_t> sortedDimGroup = dimGroup;
616 llvm::sort(sortedDimGroup);
617 int64_t dimBeforeCurrent = -1;
618 for (auto dimIdx : sortedDimGroup) {
619 // when order attr is present, adjacency dims are values like [3, 2, 1, 0]
620 // in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
621 // in increasing order
622 if (dimBeforeCurrent >= 0) {
623 if (getOrder() && !getOrder().empty()) {
624 int64_t orderBefore = origOrder[dimBeforeCurrent];
625 int64_t orderCurrent = origOrder[dimIdx];
626 if (orderBefore != (orderCurrent - 1))
627 llvm::report_fatal_error(
628 "dimensions being collapsed must be adjacent in order");
629 } else {
630 if (dimIdx != (dimBeforeCurrent + 1))
631 llvm::report_fatal_error(
632 "dimensions being collapsed must be adjacent");
633 }
634 }
635 dimBeforeCurrent = dimIdx;
636 }
637
638 int firstDim = sortedDimGroup.front();
639
640 // collapse the dimensions in dimGroup into one dimension by multiplying their
641 // sizes together
642
643 if (!sgLayout.empty()) {
644 int64_t collapsedSglayout = 1, collapsedSgData = 1;
645 for (auto dimIdx : dimGroup) {
646 collapsedSglayout *= sgLayout[dimIdx];
647 collapsedSgData *= sgData[dimIdx];
648 }
649 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
650 sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
651 sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
652 }
653 sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
654 sgData.insert(sgData.begin() + firstDim, collapsedSgData);
655 }
656
657 if (!instData.empty()) {
658 int64_t collapsedInstData = 1;
659 for (auto dimIdx : dimGroup)
660 collapsedInstData *= instData[dimIdx];
661 for (auto dimIdx : llvm::reverse(sortedDimGroup))
662 instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
663 instData.insert(instData.begin() + firstDim, collapsedInstData);
664 }
665
666 if (!laneLayout.empty()) {
667 int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
668 for (auto dimIdx : dimGroup) {
669 collapsedLaneLayout *= laneLayout[dimIdx];
670 collapsedLaneData *= laneData[dimIdx];
671 }
672 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
673 laneLayout.erase(laneLayout.begin() + dimIdx,
674 laneLayout.begin() + dimIdx + 1);
675 laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
676 }
677 laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
678 laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
679 }
680
681 SmallVector<int64_t> newOrder;
682 DenseI32ArrayAttr orderAttr = getOrder();
683 if (orderAttr && !orderAttr.empty()) {
684
685 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
686 if (dimIdx != firstDim)
687 origOrder.erase(origOrder.begin() + dimIdx);
688 }
689 // say we have orderVec = {5, 3, 2, 1, 0}
690 // Create indices [0, 1, 2, 3, 4]
691 SmallVector<size_t> indices =
692 llvm::to_vector(llvm::seq<size_t>(0, orderAttr.size()));
693
694 // Sort indices based on corresponding values
695 llvm::sort(indices,
696 [&](size_t a, size_t b) { return origOrder[a] < origOrder[b]; });
697
698 newOrder = llvm::to_vector(llvm::map_range(
699 indices, [&](size_t i) { return static_cast<int64_t>(i); }));
700 }
701
702 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
703 if (v.empty())
704 return DenseI32ArrayAttr();
705 SmallVector<int32_t> v32(v.begin(), v.end());
706 return DenseI32ArrayAttr::get(getContext(), v32);
707 };
708 auto collapsedLayout = xegpu::LayoutAttr::get(
709 getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
710 toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
711 return collapsedLayout;
712}
713
714// Derive a new layout by transpose the layout using `permutation`.
715DistributeLayoutAttr LayoutAttr::transposeDims(ArrayRef<int64_t> permutation) {
716
717 SmallVector<int64_t> origSgLayout = getEffectiveSgLayoutAsInt();
718 SmallVector<int64_t> origSgData = getEffectiveSgDataAsInt();
719 SmallVector<int64_t> origInstData = getEffectiveInstDataAsInt();
720 SmallVector<int64_t> origLaneLayout = getEffectiveLaneLayoutAsInt();
721 SmallVector<int64_t> origLaneData = getEffectiveLaneDataAsInt();
722 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
723
724 SmallVector<int32_t> sgLayout;
725 SmallVector<int32_t> sgData;
726 SmallVector<int32_t> instData;
727 SmallVector<int32_t> laneLayout;
728 SmallVector<int32_t> laneData;
729 SmallVector<int32_t> order;
730
731 for (int64_t idx : permutation) {
732 if (!origLaneLayout.empty()) {
733 laneLayout.push_back(static_cast<int32_t>(origLaneLayout[idx]));
734 laneData.push_back(static_cast<int32_t>(origLaneData[idx]));
735 }
736 if (!origInstData.empty())
737 instData.push_back(static_cast<int32_t>(origInstData[idx]));
738 if (!origSgLayout.empty()) {
739 sgLayout.push_back(static_cast<int32_t>(origSgLayout[idx]));
740 sgData.push_back(static_cast<int32_t>(origSgData[idx]));
741 }
742 order.push_back(static_cast<int32_t>(origOrder[idx]));
743 }
744 if (origLaneLayout.empty() && origSgLayout.empty())
745 order.clear();
746
747 auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
748 return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
749 };
750 return xegpu::LayoutAttr::get(getContext(), toAttr(sgLayout), toAttr(sgData),
751 toAttr(instData), toAttr(laneLayout),
752 toAttr(laneData), toAttr(order));
753}
754
755/// Check if this layout is a transpose of another layout.
756bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
757 ArrayRef<int64_t> perm,
758 const xegpu::LayoutKind kind) {
759 if (!other)
760 return false;
761 if (getRank() != other.getRank() ||
762 perm.size() != static_cast<size_t>(getRank()))
763 return false;
764 if (!isPermutationVector(perm))
765 return false;
766 auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src,
767 ArrayRef<int64_t> perm) {
768 for (const auto &ta : llvm::enumerate(perm)) {
769 if (src[ta.index()] != dst[ta.value()])
770 return false;
771 }
772 return true;
773 };
774 if (kind == xegpu::LayoutKind::Subgroup)
775 return checkTranspose(getEffectiveSgLayoutAsInt(),
776 other.getEffectiveSgLayoutAsInt(), perm) &&
777 checkTranspose(getEffectiveSgDataAsInt(),
778 other.getEffectiveSgDataAsInt(), perm) &&
779 checkTranspose(getEffectiveOrderAsInt(),
780 other.getEffectiveOrderAsInt(), perm);
781 if (kind == xegpu::LayoutKind::InstData)
782 return checkTranspose(getEffectiveInstDataAsInt(),
783 other.getEffectiveInstDataAsInt(), perm);
784 if (kind == xegpu::LayoutKind::Lane)
785 return checkTranspose(getEffectiveLaneLayoutAsInt(),
786 other.getEffectiveLaneLayoutAsInt(), perm) &&
787 checkTranspose(getEffectiveLaneDataAsInt(),
788 other.getEffectiveLaneDataAsInt(), perm) &&
789 checkTranspose(getEffectiveOrderAsInt(),
790 other.getEffectiveOrderAsInt(), perm);
791
792 return false;
793}
794
795bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
796 SmallVector<int64_t> shape,
797 xegpu::LayoutKind level) {
798 if (!other)
799 return false;
800 if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
801 // short cut when order is the same, no need to compute coords and compare
802 if (level == xegpu::LayoutKind::Subgroup)
803 if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
804 getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
805 return true;
806 if (level == xegpu::LayoutKind::Lane)
807 if (getEffectiveLaneLayoutAsInt() ==
808 other.getEffectiveLaneLayoutAsInt() &&
809 getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
810 return true;
811 }
812
813 auto compareCoordsForAllIds = [&](int64_t size) {
814 for (int64_t id : llvm::seq<int64_t>(0, size)) {
815 auto coords = computeStaticDistributedCoords(id, shape);
816 auto otherCoords = other.computeStaticDistributedCoords(id, shape);
817 if (coords != otherCoords)
818 return false;
819 }
820 return true;
821 };
822
823 if (level == xegpu::LayoutKind::Subgroup) {
824 int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
825 return compareCoordsForAllIds(wgSize);
826 }
827 if (level == xegpu::LayoutKind::InstData) {
828 return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
829 }
830 if (level == xegpu::LayoutKind::Lane) {
831 int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
832 return compareCoordsForAllIds(subgroupSize);
833 }
834 return true;
835}
836
837//===----------------------------------------------------------------------===//
838// XeGPU_SliceAttr
839//===----------------------------------------------------------------------===//
840LogicalResult
841SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
842 xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
843
844 if (!dims)
845 return emitError() << "expected dims attribute";
846
847 // check every element in dims is unique and smaller than rank
848 llvm::SmallDenseSet<int64_t> seen;
849 for (int64_t dim : dims.asArrayRef()) {
850 if (dim < 0)
851 return emitError() << "invalid dim (" << dim << ") in slice attribute.";
852 if (!seen.insert(dim).second)
853 return emitError() << "repeated dim (" << dim << ") in slice attribute.";
854 }
855 return success();
856}
857
/// Collapses a chain of nested SliceAttrs into a single SliceAttr whose
/// parent is the root LayoutAttr and whose dims are all sliced dims
/// re-expressed in the root layout's dimension space.
SliceAttr SliceAttr::flatten() const {
  xegpu::DistributeLayoutAttr parent = getParent();
  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

  // Walk up the parent chain, collecting the dims sliced at each level
  // (innermost slice first).
  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
    parent = sliceAttr.getParent();
    slicedDims.push_back(sliceAttr.getDims());
  }

  // The root of the chain is always a LayoutAttr.
  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
  SmallVector<int64_t> indices =
      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));

  // Get remaining dims (flattened) by applying slice ops with all slicedDims,
  // outermost slice applied first.
  SmallVector<int64_t> remainingDims(indices);
  for (auto dim : llvm::reverse(slicedDims))
    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
                                        dim.asArrayRef());

  // Get flattened sliced dims by applying slice ops with the remaining dims.
  SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));

  return xegpu::SliceAttr::get(
      getContext(), layoutAttr,
      DenseI64ArrayAttr::get(getContext(), flattendDims));
}
885
886FailureOr<SmallVector<Value>>
887SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
888 SliceAttr attr = flatten();
889 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
890 return parent.delinearizeId(builder, loc, linearId);
891}
892
// Implements DistributeLayoutAttr::computeDistributedCoords to generate
// instructions for computing multi-dimensional offsets when distributed by
// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                    Value linearId, ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");

  // Select the distribution granularity: subgroups within a workgroup, or
  // lanes within a subgroup.
  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }

  if (subShape.empty())
    return failure();

  // Delinearize the id via the root LayoutAttr of the slice chain.
  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();

  // The effective ids for offset computation correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
  SmallVector<Value> canonicalIds =
      XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);

  return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
}
929
/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
/// compute multi-dimensional offsets for a given linear ID when distributed by
/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
/// only the non-sliced dimensions for coordinate computation.
SmallVector<SmallVector<int64_t>>
SliceAttr::computeStaticDistributedCoords(int64_t linearId,
                                          ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");

  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  SmallVector<int64_t> instData;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    instData = getEffectiveInstDataAsInt();
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  }
  // When inst_data is present it overrides lane_data as the sub-shape, and
  // the id is forced to 0.
  // NOTE(review): forcing linearId = 0 looks intentional (all ids map to the
  // same tile at inst-data granularity) — confirm against LayoutAttr's
  // version of this method.
  if (!instData.empty()) {
    linearId = 0;
    subShape = instData;
  }

  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");

  // Delinearize the ID using the parent layout (same as the IR version).
  SliceAttr flattened = flatten();
  auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
  SmallVector<int64_t> parentLayoutVec;
  if (parent.isForWorkgroup())
    parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
  else
    parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();

  // Decompose linearId into one id per dim following the layout order;
  // order[0] is consumed first (fastest-varying dimension).
  SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
  SmallVector<int64_t> allIds(parentLayoutVec.size());
  int64_t remaining = linearId;
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
    if (i < order.size() - 1)
      remaining = remaining / parentLayoutVec[dimIdx];
  }

  // The effective IDs for coordinate computation correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
  SmallVector<int64_t> canonicalIds =
      XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);

  return genStaticCoordinates(canonicalIds, layout, subShape, shape);
}
984
985bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
986 auto flattenedThis = flatten();
987 // If other is a LayoutAttr, just compare directly with parent of
988 // flattenedThis.
989 if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
990 return flattenedThis.getParent() == otherLayout;
991 // If other is a SliceAttr, flatten it first before comparing.
992 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
993 // Both must have common parent LayoutAttr.
994 if (flattenedThis.getParent() != flattenedOther.getParent())
995 return false;
996 // otherFlattened's sliced dims must be a subset of flattenedThis's sliced
997 // dims.
998 llvm::SmallDenseSet<int64_t> thisDims(
999 flattenedThis.getDims().asArrayRef().begin(),
1000 flattenedThis.getDims().asArrayRef().end());
1001 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
1002 [&](int64_t dim) { return thisDims.contains(dim); });
1003}
1004
1005bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
1006 if (dyn_cast<xegpu::LayoutAttr>(other))
1007 return false;
1008
1009 auto flattenedThis = flatten();
1010 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
1011
1012 return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
1013 (flattenedThis.getDims() == flattenedOther.getDims()));
1014}
1015
/// Returns true if this layout assigns the same coordinates as `other` for
/// every id at the requested `level` of distribution, over `shape`.
/// Mirrors LayoutAttr::isCompatibleWith but sizes the id space from the
/// flattened parent layout.
bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                                 SmallVector<int64_t> shape,
                                 xegpu::LayoutKind level) {
  if (!other)
    return false;
  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
    // Shortcut when order is the same: identical layout/data vectors imply
    // identical coordinates, no need to compute and compare them.
    if (level == xegpu::LayoutKind::Subgroup)
      if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
          getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
        return true;
    if (level == xegpu::LayoutKind::Lane)
      if (getEffectiveLaneLayoutAsInt() ==
              other.getEffectiveLaneLayoutAsInt() &&
          getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
        return true;
  }

  // Compares coordinates produced by both layouts for every id in [0, size).
  auto compareCoordsForAllIds = [&](int64_t size) {
    for (int64_t id : llvm::seq<int64_t>(0, size)) {
      auto coords = computeStaticDistributedCoords(id, shape);
      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
      if (coords != otherCoords)
        return false;
    }
    return true;
  };

  // The id space is sized from the root layout of the slice chain.
  auto flattenedThis = flatten();
  auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
  if (level == xegpu::LayoutKind::Subgroup) {
    int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
    return compareCoordsForAllIds(wgSize);
  }
  if (level == xegpu::LayoutKind::InstData) {
    // Inst-data compatibility is a direct vector comparison.
    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
  }
  if (level == xegpu::LayoutKind::Lane) {
    int64_t subgroupSize = computeProduct(parent.getEffectiveLaneLayoutAsInt());
    return compareCoordsForAllIds(subgroupSize);
  }
  return true;
}
1059
1060xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
1061 if (sliceDimsToDrop.empty())
1062 return *this;
1063 SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
1064 for (auto dim : sliceDimsToDrop) {
1065 auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
1066 assert(foundIt != sliceDims.end() &&
1067 "Expected to find the specified reduction dim in slice dims");
1068 sliceDims.erase(foundIt);
1069 }
1070
1071 auto sliceWithoutDims = xegpu::SliceAttr::get(
1072 this->getContext(), getParent(),
1073 DenseI64ArrayAttr::get(this->getContext(), sliceDims));
1074
1075 return sliceWithoutDims;
1076}
1077
1078// Helper function to adjust dimensions from sliced space to parent space
1079// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
1080// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
1081// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
1082// it maps to dim [2] in parent space.
1083static SmallVector<int64_t>
1085 ArrayRef<int64_t> sliceDims) {
1086 // Rather than recovering the exact parent rank, we compute a safe upper
1087 // bound so that dimsToMap can be adjusted safely. This upper bound is
1088 // defined as max(dimsToMap, sliceDims) + 1 + sliceDims.size().
1089 int64_t maxDim = -1;
1090 maxDim =
1091 std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
1092 maxDim =
1093 std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
1094 int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
1095
1096 // get remaining dims in parent space after applying slicing with parent's
1097 // slice Dims
1098 llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
1099 sliceDims.end());
1100 SmallVector<int64_t> remainingDims;
1101 for (int64_t i = 0; i < parentSpaceRank; ++i) {
1102 if (!slicedDimsSet.contains(i))
1103 remainingDims.push_back(i);
1104 }
1105
1106 // Map unit dims from sliced space to parent space
1107 SmallVector<int64_t> adjustUnitDims;
1108 for (auto dim : dimsToMap) {
1109 int64_t mappedDim = remainingDims[dim];
1110 adjustUnitDims.push_back(mappedDim);
1111 }
1112
1113 return adjustUnitDims;
1114}
1115
1116// set the layout for unit dims: sg_data, inst_data and lane_data to 1
1117DistributeLayoutAttr
1118SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
1119 DistributeLayoutAttr parentLayout = getParent();
1120
1121 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1122
1123 SmallVector<int64_t> adjustUnitDims =
1124 mapSlicedDimsToParentSpace(unitDims, sliceDims);
1125
1126 return SliceAttr::get(getContext(),
1127 parentLayout.setUnitDimData(adjustUnitDims), getDims());
1128}
1129
// Set the layout fields to 1 for the specified unit dims (expressed in this
// slice's dimension space) by delegating to the parent's setUnitDimLayout.
DistributeLayoutAttr
SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
  DistributeLayoutAttr parentLayout = getParent();

  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();

  // Translate sliced-space dims into the parent's dimension space.
  SmallVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, sliceDims);

  return SliceAttr::get(
      getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
}
1143
1144// Derive a new layout with sg_data, inst_data and lane_data set to the
1145// specified values for the given dimension
1146DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
1147 int64_t instData, int64_t laneData) {
1148 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1149 auto parent = getParent();
1150
1151 SmallVector<int64_t> dimSet;
1152 dimSet.push_back(dim);
1153 SmallVector<int64_t> adjustDims =
1154 mapSlicedDimsToParentSpace(dimSet, sliceDims);
1155 return SliceAttr::get(
1156 getContext(),
1157 parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
1158}
1159
// Derive a new layout by removing dimensions. `dimGroup` specifies a group of
// dimensions to be removed in the derived layout.
//
// Example: drop the 2nd dimension from a rank-3 sliced view.
//
// Suppose:
//   xegpu.layout = slice<layout<[V0, V1, V2, V3, V4]>, [1, 3]>
//
// The slice removes parent dims [1, 3], so the sliced-space dims map to
// parent dims [V0, V2, V4].
//
// If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
// parent dim 2, resulting in parent layout [V0, V1, V3, V4] after dropping.
// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to
// [1, 2].
//
// Result:
//   xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
  // Map the dims to drop from sliced space into parent space.
  SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
  SmallVector<int64_t> dimsInParentSpace =
      mapSlicedDimsToParentSpace(dimGroup, sliceDims);

  auto droppedParent = getParent().dropDims(dimsInParentSpace);

  // Adjust the sliced dims after dropping dims in parent space. For example,
  // if we drop dim 2 in parent space, the dims after dim 2 will all be
  // shifted by 1, so sliced dim 3 will be adjusted to 2.
  SmallVector<int64_t> newSliceDims;
  for (int64_t d : sliceDims) {
    int64_t offset =
        llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
    newSliceDims.push_back(d - offset);
  }

  return SliceAttr::get(getContext(), droppedParent,
                        DenseI64ArrayAttr::get(getContext(), newSliceDims));
}
1199
1200// Derive a new layout by collapsing dimensions.
1201// `dimGroup` specifies a group of adjacent dimensions
1202// that are collapsed into a single dimension in the derived layout.
1203DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
1204
1205 // Map the sliced dims from parent space to collapsed space
1206 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1207 assert("expect sliceDims not being collapsed" &&
1208 llvm::none_of(dimGroup, [&](int64_t dim) {
1209 return llvm::is_contained(sliceDims, dim);
1210 }));
1211 SmallVector<int64_t> dimsInParentSpace =
1212 mapSlicedDimsToParentSpace(dimGroup, sliceDims);
1213
1214 auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
1215 return SliceAttr::get(getContext(), collapsedParent,
1216 DenseI64ArrayAttr::get(getContext(), sliceDims));
1217}
1218
1220 ArrayRef<int64_t> permutation) {
1221 SmallVector<int64_t> sortedSliceDims = llvm::to_vector(sliceDims);
1222 llvm::sort(sortedSliceDims);
1223
1224 for (size_t i = 1; i < sortedSliceDims.size(); ++i) {
1225 assert((sortedSliceDims[i] == sortedSliceDims[i - 1] + 1) &&
1226 "slice dims non consecutive, cannot be transposed");
1227 }
1228
1229 SmallVector<int64_t> permForParent;
1230 if (sortedSliceDims.front() == 0) {
1231 // Example: sliceDims.size() = 2, permutation= {1, 0}
1232 // result: {3, 2, 1, 0}.
1233 for (int64_t dim : permutation)
1234 permForParent.push_back(dim + sortedSliceDims.size());
1235 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1236 permForParent.push_back(i);
1237 } else {
1238 // Example: sliceDims.size() = 2, permutation = {0, 1}
1239 // result: {3, 2, 0, 1}.
1240 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1241 permForParent.push_back(i + permutation.size());
1242 for (int64_t dim : permutation)
1243 permForParent.push_back(dim);
1244 }
1245 return permForParent;
1246}
1247
1248// Derive a new layout by transpose the layout using `permutation`.
1249DistributeLayoutAttr SliceAttr::transposeDims(ArrayRef<int64_t> permutation) {
1250 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1251 DistributeLayoutAttr parent = getParent();
1252 SmallVector<int64_t> permForParent =
1253 getPermForParentLayout(sliceDims, permutation);
1254 auto transposedParent = parent.transposeDims(permForParent);
1255 return SliceAttr::get(getContext(), transposedParent,
1256 DenseI64ArrayAttr::get(getContext(), sliceDims));
1257}
1258
1259/// Check if this layout is a transpose of another layout.
1260bool SliceAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
1261 ArrayRef<int64_t> perm,
1262 const xegpu::LayoutKind kind) {
1263 // other must be a SliceAttr with the same slice dims.
1264 auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
1265 if (!otherSlice || getDims() != otherSlice.getDims())
1266 return false;
1267 // check whether the parent layout is transpose of each other.
1268 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1269 DistributeLayoutAttr parent = getParent();
1270 SmallVector<int64_t> permForParent = getPermForParentLayout(sliceDims, perm);
1271 auto otherParent = otherSlice.getParent();
1272 return parent.isTransposeOf(otherParent, permForParent, kind);
1273}
1274
1275//===----------------------------------------------------------------------===//
1276// XeGPU_RangeAttr
1277//===----------------------------------------------------------------------===//
1278
1279LogicalResult
1280RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1281 IntegerAttr startOfRange, IntegerAttr endOfRange) {
1282 if (startOfRange.getInt() >= endOfRange.getInt())
1283 return emitError() << "'end' : " << endOfRange.getInt()
1284 << " must be greater than 'start' : "
1285 << startOfRange.getInt();
1286
1287 return success();
1288}
1289
1290//===----------------------------------------------------------------------===//
1291// XeGPU_TensorDescType
1292//===----------------------------------------------------------------------===//
1293
/// Parses a TensorDescType of the form:
///   `<` dim-list element-type (`,` attr)* `>`
/// where each trailing attribute is either a layout (LayoutAttr) or an
/// encoding (Block/ScatterTensorDescAttr), in any order. An unrecognized
/// trailing attribute aborts parsing (without an extra diagnostic).
mlir::Type TensorDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Parse the optional trailing attributes, dispatching on attribute kind.
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // A missing encoding defaults to a default-constructed BlockTensorDescAttr.
  MLIRContext *ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}
1343
1344void TensorDescType::print(AsmPrinter &printer) const {
1345 printer << "<";
1346
1347 auto shape = getShape();
1348 for (int64_t dim : shape) {
1349 if (mlir::ShapedType::isDynamic(dim))
1350 printer << '?';
1351 else
1352 printer << dim;
1353 printer << 'x';
1354 }
1355
1356 printer << getElementType();
1357
1358 auto encoding = getEncoding();
1359 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1360 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
1361 printer << ", " << encoding;
1362
1363 if (auto layout = getLayout())
1364 printer << ", " << layout;
1365
1366 printer << ">";
1367}
1368
1369TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
1370 mlir::Type elementType, int array_length,
1371 bool boundary_check,
1372 MemorySpace memory_space,
1373 mlir::Attribute layout) {
1374 auto *context = elementType.getContext();
1375 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
1376 boundary_check);
1377 return Base::get(context, shape, elementType, attr, layout);
1378}
1379
1380TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
1381 mlir::Type elementType, int chunk_size,
1382 MemorySpace memory_space,
1383 mlir::Attribute layout) {
1384 auto *context = elementType.getContext();
1385 auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
1386 return Base::get(context, shape, elementType, attr, layout);
1387}
1388
1389LogicalResult
1390TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
1391 llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
1392 mlir::Attribute encoding, mlir::Attribute layout) {
1393 size_t rank = shape.size();
1394
1395 if (rank == 0)
1396 return emitError() << "expected non-zero rank tensor";
1397
1398 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1399 if (blockAttr) {
1400 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
1401 if (rank > 1 && memorySpaceAttr &&
1402 memorySpaceAttr.getValue() == MemorySpace::SLM)
1403 return emitError() << "SLM is only supported for 1D block tensor";
1404 }
1405
1406 if (!elementType.isIntOrFloat())
1407 return emitError() << "unsupported element type " << elementType
1408 << ": expected integer or float";
1409
1410 // for gather and scatter ops, Low-precision types are packed in 32-bit
1411 // units.
1412 unsigned bitWidth = elementType.getIntOrFloatBitWidth();
1413 int chunkAlignmentFactor =
1416 : 1;
1417 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
1418 if (scatterAttr) {
1419 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
1420 if (rank == 1 && chunkSize != 1)
1421 return emitError() << "expected non-contiguous elements for 1D tensor";
1422
1423 // If chunk size > 1, the second dimension of the tensor shape must be
1424 // equal to chunk size and it must be a multiple of the
1425 // chunkAlignmentFactor.
1426 if (chunkSize > 1) {
1427 if (shape.back() != chunkSize)
1428 return emitError() << "expected last dim of tensor to match chunk size";
1429 if (shape.back() % chunkAlignmentFactor != 0)
1430 return emitError() << "expected last dim of tensor to be a multiple of "
1431 << chunkAlignmentFactor;
1432 }
1433 }
1434
1435 auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
1436 if (layoutAttr) {
1437 if (rank != (size_t)layoutAttr.getRank())
1438 return emitError() << "expected layout rank to match tensor rank";
1439
1440 auto laneData = layoutAttr.getLaneData();
1441 if (scatterAttr && laneData) {
1442 // Validate subgroup mapping rules for scattered tensors.
1443 // if chunkSize > 1, the last dimension of the tensor should
1444 // be distributed in the units divisible by chunkAlignmentFactor.
1445 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
1446 if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
1447 return emitError()
1448 << "expected last dim of lane_data to be a multiple of: "
1449 << chunkAlignmentFactor;
1450 }
1451
1452 if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
1453 std::string shapeStr;
1454 llvm::raw_string_ostream stream(shapeStr);
1455 llvm::interleaveComma(shape, stream);
1456 return emitError() << "cannot distribute [" << shapeStr << "] using "
1457 << layoutAttr;
1458 }
1459 }
1460 return success();
1461}
1462
1463//===----------------------------------------------------------------------===//
1464// XeGPU_MemDescType
1465//===----------------------------------------------------------------------===//
/// Parses a MemDescType of the form `< dim-list element-type (, layout)? >`.
mlir::Type MemDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  // Static dims only (allowDynamic = false), with trailing 'x'.
  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Parse the optional memory layout attribute.
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  MLIRContext *ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}
1505
1506void MemDescType::print(AsmPrinter &printer) const {
1507 printer << "<";
1508
1509 printer.printDimensionList(getShape());
1510 printer << 'x';
1511 printer << getElementType();
1512
1513 if (auto layout = getMemLayout())
1514 printer << ", " << layout;
1515
1516 printer << ">";
1517}
1518
1519//===----------------------------------------------------------------------===//
1520// XeGPU_MemDescType
1521//===----------------------------------------------------------------------===//
1522
1523Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
1524
1525 auto *context = parser.getContext();
1526 llvm::SMLoc loc = parser.getCurrentLocation();
1527
1528 llvm::SmallDenseSet<StringRef> seenKeys;
1529 SmallVector<NamedAttribute> attributes;
1530
1531 auto parseElt = [&]() -> ParseResult {
1532 StringRef nameId;
1533 if (failed(parser.parseKeyword(&nameId)))
1534 return parser.emitError(loc, "expected valid attribute name");
1535
1536 if (!seenKeys.insert(nameId).second)
1537 return parser.emitError(loc, "duplicate key '")
1538 << nameId << " in mem layout attribute";
1539
1540 if (failed(parser.parseEqual()))
1541 return failure();
1542
1543 Attribute attr;
1544 if (failed(parser.parseAttribute(attr)))
1545 return failure();
1546 attributes.emplace_back(nameId, attr);
1547 return success();
1548 };
1549
1550 // Parse literal '<'
1551 if (parser.parseLess())
1552 return {};
1553
1554 if (failed(parser.parseCommaSeparatedList(parseElt)))
1555 return {};
1556
1557 // Parse literal '>'
1558 if (parser.parseGreater())
1559 return {};
1560
1561 return parser.getChecked<MemLayoutAttr>(
1562 loc, context, DictionaryAttr::get(context, attributes));
1563}
1564
1565void MemLayoutAttr::print(AsmPrinter &printer) const {
1566 printer << "<";
1567 ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
1568 for (size_t i = 0; i < attrs.size(); i++) {
1569 printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
1570 if (i < attrs.size() - 1)
1571 printer << ", ";
1572 }
1573 printer << ">";
1574}
1575// a helper utility to perform binary operation on OpFoldResult.
1576// If both a and b are attributes, it will simply return the result.
1577// Otherwise, the corresponding arith op will be generated, and an
1578// contant op will be created if one of them is an attribute.
1579template <typename ArithOp>
1581 OpBuilder &builder) {
1582 auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
1583 auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
1584 return ArithOp::create(builder, loc, aVal, bVal).getResult();
1585}
1586
1587// a helper utility to perform division operation on OpFoldResult and int64_t.
1588#define div(a, b) \
1589 genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
1590
// a helper utility to perform remainder operation on OpFoldResult and int64_t.
1592#define rem(a, b) \
1593 genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
1594
1595// a helper utility to perform multiply operation on OpFoldResult and int64_t.
1596#define mul(a, b) \
1597 genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
1598
1599// a helper utility to perform addition operation on two OpFoldResult.
1600#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
1601
1602// block the given offsets according to the block shape
1603// say the original offset is [y, x], and the block shape is [By, Bx],
1604// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1606 ArrayRef<OpFoldResult> offsets,
1607 ArrayRef<int64_t> blockShape) {
1608
1609 assert(offsets.size() == blockShape.size() &&
1610 "offsets and blockShape must have the same size");
1611 SmallVector<OpFoldResult> blockedOffsets;
1612 SmallVector<OpFoldResult> divs, rems;
1613
1614 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1615 divs.push_back(div(offset, block));
1616 rems.push_back(rem(offset, block));
1617 }
1618 blockedOffsets.append(divs.begin(), divs.end());
1619 blockedOffsets.append(rems.begin(), rems.end());
1620
1621 return blockedOffsets;
1622}
1623
// Get strides as a vector of integers for MemDesc, in the blocked space:
// the result contains the outer (block-index) strides followed by the inner
// (intra-block) strides.
SmallVector<int64_t> MemDescType::getStrideShape() {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());

  // Decode the declared strides from the attribute.
  ArrayAttr strideAttr = getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  SmallVector<int64_t> innerBlkShape = getBlockShape();

  // get perm from FCD to LCD
  // perm[i] = the dim with i-th smallest stride
  SmallVector<int, 4> perm =
      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });

  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");

  // Intra-block strides: contiguous in the fastest-changing dim, then scaled
  // by the block extents following the stride order.
  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
  innerBlkStride[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    innerBlkStride[perm[i]] =
        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];

  // compute the original matrix shape using the stride info
  // and compute the number of blocks in each dimension
  // The shape of highest dim can't be derived from stride info,
  // but doesn't impact the stride computation for blocked layout.
  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
  }

  // Total number of elements in one block.
  int64_t innerBlkSize = 1;
  for (auto s : innerBlkShape)
    innerBlkSize *= s;

  // Block-index strides: one whole block is the unit for the fastest dim,
  // then scaled by the block counts following the stride order.
  SmallVector<int64_t> outerBlkStride(matrixShape.size());
  outerBlkStride[perm[0]] = innerBlkSize;
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    outerBlkStride[perm[i + 1]] =
        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
  }

  // combine the inner and outer strides
  SmallVector<int64_t> blockedStrides;
  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());

  return blockedStrides;
}
1680
// Calculate the linear offset from the multi-dimensional offsets using the
// (possibly blocked) stride info.
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
                                    ArrayRef<OpFoldResult> offsets) {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
  SmallVector<int64_t> blockShape = getBlockShape();
  SmallVector<int64_t> strides = getStrideShape();
  SmallVector<OpFoldResult> blockedOffsets;

  // blockshape equal to matrixshape means no blocking
  if (llvm::equal(blockShape, matrixShape)) {
    // Unblocked case: drop the outer (block-index) strides and use the
    // offsets directly with the inner strides.
    strides.erase(strides.begin(), strides.begin() + matrixShape.size());
  } else {
    assert(offsets.size() == blockShape.size() &&
           "offsets and blockShape must have the same size");
    // say the original offset is [y, x], and the block shape is [By, Bx],
    // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]

    SmallVector<OpFoldResult> divs, rems;

    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
      divs.push_back(div(offset, block));
      rems.push_back(rem(offset, block));
    }
    blockedOffsets.append(divs.begin(), divs.end());
    blockedOffsets.append(rems.begin(), rems.end());
    offsets = blockedOffsets;
  }

  // Accumulate sum(offsets[i] * strides[i]), starting from constant 0.
  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
  for (size_t i = 0; i < offsets.size(); ++i) {
    OpFoldResult mulResult = mul(offsets[i], strides[i]);
    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
  }

  return linearOffset;
}
1721
1722} // namespace xegpu
1723} // namespace mlir
1724
1725#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
1726#define GET_ATTRDEF_CLASSES
1727#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
1728#define GET_TYPEDEF_CLASSES
1729#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
return success()
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define div(a, b)
#define rem(a, b)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
constexpr unsigned generalPackedFormatBitSize
Definition uArchBase.h:32
static SmallVector< SmallVector< int64_t > > genStaticCoordinates(llvm::ArrayRef< int64_t > canonicalIds, llvm::ArrayRef< int64_t > layout, llvm::ArrayRef< int64_t > subShape, llvm::ArrayRef< int64_t > shape)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
static SmallVector< int64_t > mapSlicedDimsToParentSpace(const SmallVector< int64_t > &dimsToMap, ArrayRef< int64_t > sliceDims)
SmallVector< OpFoldResult > getBlockedOffsets(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > offsets, ArrayRef< int64_t > blockShape)
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder)
static SmallVector< SmallVector< Value > > genCoordinates(OpBuilder &builder, Location loc, SmallVector< Value > delinearizedId, ArrayRef< int64_t > subShapesLayout, ArrayRef< int64_t > subShape, ArrayRef< int64_t > srcShape)
SmallVector< int64_t > getPermForParentLayout(ArrayRef< int64_t > sliceDims, ArrayRef< int64_t > permutation)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.