MLIR 23.0.0git
XeGPUDialect.cpp
1//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/Builders.h"
16#include "llvm/ADT/SmallVectorExtras.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/Debug.h"
19
20using std::optional;
21
22namespace mlir {
23namespace xegpu {
24
25void XeGPUDialect::initialize() {
26 addTypes<
27#define GET_TYPEDEF_LIST
28#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
29 >();
30 addOperations<
31#define GET_OP_LIST
32#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
33 >();
34 addAttributes<
35#define GET_ATTRDEF_LIST
36#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
37 >();
38}
39#define GET_OP_INTERFACE_CLASSES
40#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
41
42// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
43// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
44// within each distribution unit.
45// Example:
46// WG data is 128x256. SG data is 16x32 in a 4x2 layout; this gives a
47// distribution unit of shape 64x64, and we have 2x4 such distribution units.
48// `delinearizedId` is used to identify a 16x32 of a subgroup in each
49// distribution unit.
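// For illustration (a sketch of the arithmetic, continuing the example above):
//   srcShape = [128, 256], subShapesLayout = [4, 2], subShape = [16, 32]
//   => distUnitShape = [64, 64], and srcShape holds 2x4 distribution units.
//   For delinearizedId = [1, 1], distUnitLocalOffset = [16, 32], so the
//   returned coordinates are [16, 32], [16, 96], [16, 160], [16, 224],
//   [80, 32], [80, 96], [80, 160], [80, 224].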
50static SmallVector<SmallVector<Value>>
51genCoordinates(OpBuilder &builder, Location loc,
52               SmallVector<Value> delinearizedId,
53 ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
54 ArrayRef<int64_t> srcShape) {
55  SmallVector<SmallVector<Value>> coordinates;
56
57 // A distribution unit must be less than or equal to `srcShape`
58 SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
59 llvm::zip_equal(srcShape,
60 computeElementwiseMul(subShapesLayout, subShape)),
61 [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
62
63 // Get the offset of `subShape` within a distribution unit.
64 SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
65 llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
66 return builder.createOrFold<arith::MulIOp>(
67 loc, std::get<0>(t),
68 builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
69 });
70
71 // For each dist unit
72 for (SmallVector<int64_t> unitOffs :
73 StaticTileOffsetRange(srcShape, distUnitShape)) {
74 // Get dist unit offset within `srcShape`.
75    SmallVector<Value> base =
76        llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
77 return arith::ConstantIndexOp::create(builder, loc, d);
78 });
79 // Calculate `subShape` offset within `srcShape`.
80    SmallVector<Value> adds =
81        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
82 [&](const auto &t) -> Value {
83 return builder.createOrFold<arith::AddIOp>(
84 loc, std::get<0>(t), std::get<1>(t));
85 });
86 // Do not go beyond `srcShape` bounds.
87 SmallVector<Value> mods = llvm::map_to_vector(
88 llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
89 return builder.createOrFold<arith::RemUIOp>(
90 loc, std::get<0>(t),
91 arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
92 });
93
94 coordinates.push_back(mods);
95 }
96 return coordinates;
97}
98
99// Checks if the given shape can be evenly distributed based on the layout
100// and data factors provided by the LayoutAttr.
101bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
102 xegpu::DistributeLayoutAttr attr) {
103 assert(attr && "Layout attribute is missing.");
104
105 // Checks whether the given shape can be evenly distributed using the
106 // specified layout and data attributes. If successful, it returns the work
107 // size for each compute unit; otherwise, it returns `std::nullopt`. The work
108 // size per compute unit is calculated as follows:
109 // - If `data` is null: newShape[i] = shape[i] / layout[i]
110 // - If `data` is not null: newShape[i] = data[i]
111 // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
112 // smaller than `layout[i] * data[i]`, allowing multiple compute units to
113 // share the data.
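  // For illustration (a sketch): shape = [128, 256] with sg_layout = [4, 2]
  // and sg_data = [16, 32] first gives newShape = [32, 128]; since [32, 128]
  // is divisible by sg_data, the per-subgroup work size becomes [16, 32],
  // which is subsequently checked against inst_data and lane_layout/lane_data.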
114 auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
115                           llvm::ArrayRef<int64_t> layout,
116                           llvm::ArrayRef<int64_t> data,
117                           bool rr = true) -> optional<SmallVector<int64_t>> {
118    SmallVector<int64_t> newShape(shape.begin(), shape.end());
119    if (layout.size()) {
120 if (layout.size() != shape.size())
121 return std::nullopt;
122 auto ratio = computeShapeRatio(shape, layout);
123 if (ratio.has_value()) {
124 newShape = ratio.value();
125 } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
126 return std::nullopt;
127 }
128 // Round-robin case: continue with original newShape
129 }
130
131 if (data.size()) {
132 if (data.size() != shape.size())
133 return std::nullopt;
134 auto ratio = computeShapeRatio(newShape, data);
135 if (!ratio.has_value() && rr)
136 ratio = computeShapeRatio(data, newShape);
137 if (!ratio.has_value())
138 return std::nullopt;
139
140 // if data is not null, we always return it for next phase.
141 newShape = data;
142 }
143 return newShape;
144 };
145
146 // check the sgLayout and sgData
147 auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
148 attr.getEffectiveSgDataAsInt());
149 if (!maybeSgShape)
150 return false;
151 auto sgShape = maybeSgShape.value();
152
153  // check inst_data; it has no layout and does not need round-robin
154 auto maybeInstShape =
155 tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
156 if (!maybeInstShape)
157 return false;
158 auto instShape = maybeInstShape.value();
159
160 // check LaneLayout and LaneData
161 auto maybeLaneShape =
162 tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
163 attr.getEffectiveLaneDataAsInt(), false);
164 return maybeLaneShape.has_value();
165}
166
167//===----------------------------------------------------------------------===//
168// XeGPU_BlockTensorDescAttr
169//===----------------------------------------------------------------------===//
170BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
171 xegpu::MemorySpace memory_space,
172 int array_length,
173 bool boundary_check) {
174 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
175 auto lengthAttr =
176 IntegerAttr::get(IntegerType::get(context, 64), array_length);
177 auto boundaryAttr = BoolAttr::get(context, boundary_check);
178 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
179}
180
181bool BlockTensorDescAttr::hasDefaultsOnly() {
182 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
183 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
184}
185
186//===----------------------------------------------------------------------===//
187// XeGPU_ScatterTensorDescAttr
188//===----------------------------------------------------------------------===//
189ScatterTensorDescAttr
190ScatterTensorDescAttr::get(mlir::MLIRContext *context,
191 xegpu::MemorySpace memory_space, int chunk_size) {
192 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
193 auto chunkSizeAttr =
194 IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
195 return Base::get(context, scopeAttr, chunkSizeAttr);
196}
197
198LogicalResult ScatterTensorDescAttr::verify(
199 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
200 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
201 int64_t chunkSize = chunk_size.getInt();
202 if (chunkSize <= 0)
203 return emitError() << "invalid chunk size";
204
205 return success();
206}
207
208//===----------------------------------------------------------------------===//
209// XeGPU_LayoutAttr
210//===----------------------------------------------------------------------===//
211LogicalResult
212LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
213 DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
214 DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
215 DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {
216
217 // Special case for store_matrix
218 if (!sg_layout && !inst_data && !lane_layout)
219 return success();
220
221  // Check that sg_layout, inst_data and lane_layout have the same rank when
222  // they are present.
223
224 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
225 return emitError()
226 << "expected sg_layout and inst_data to have the same rank";
227 }
228
229 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
230 return emitError()
231 << "expected sg_layout and lane_layout to have the same rank";
232 }
233
234 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
235 return emitError() << "expected inst_data and lane_layout to have the same "
236 "rank, got inst_data "
237 << inst_data.size() << ", lane_layout "
238 << lane_layout.size();
239 }
240
241 // sg_data is optional for Workgroup layout, but its presence requires
242 // sg_layout.
243 if (sg_data) {
244 if (!sg_layout)
245 return emitError() << "expected sg_layout being used with sg_data";
246 if (sg_data.size() != sg_layout.size())
247 return emitError()
248 << "expected sg_data and sg_layout to have the same rank";
249 }
250
251 // lane_data is optional for Subgroup layout, but its presence requires
252 // lane_layout.
253 if (lane_data) {
254 if (!lane_layout)
255 return emitError() << "expected lane_layout being used with lane_data";
256 if (lane_data.size() != lane_layout.size())
257 return emitError()
258 << "expected lane_data and lane_layout to have the same rank";
259 }
260
261 if (order) {
262 if (!sg_layout && !lane_layout)
263 return emitError()
264 << "expected sg_layout/lane_layout being used with order";
265
266 if (sg_layout && order.size() != sg_layout.size())
267 return emitError()
268 << "expected order and sg_layout to have the same rank";
269
270 if (lane_layout && order.size() != lane_layout.size())
271 return emitError()
272 << "expected order and lane_layout to have the same rank";
273 }
274
275 return success();
276}
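// For illustration, a layout satisfying the rules above might look like the
// following (assumed assembly form):
//   #xegpu.layout<sg_layout = [4, 2], sg_data = [16, 32],
//                 lane_layout = [1, 16], lane_data = [1, 1]>
// where every field that is present has the same rank.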
277
278FailureOr<SmallVector<Value>>
279LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
280
281 SmallVector<int64_t> sgLayoutInt;
282 if (isForWorkgroup()) {
283 sgLayoutInt = getEffectiveSgLayoutAsInt();
284 } else if (isForSubgroup()) {
285 sgLayoutInt = getEffectiveLaneLayoutAsInt();
286 } else {
287 return failure();
288 }
289
290 DenseI32ArrayAttr orderAttr = getOrder();
291
292 // Handle order attribute
293 SmallVector<int64_t> order;
294 if (orderAttr && !orderAttr.empty()) {
295 order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
296 return static_cast<int64_t>(idx);
297 });
298 } else {
299 // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
300 order = llvm::to_vector(
301 llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
302 }
303
304 if (order.size() != sgLayoutInt.size()) {
305 return failure();
306 }
307
308 SmallVector<Value> result(sgLayoutInt.size());
309 Value remaining = linearId;
310
311 /// Process dimensions in the order they appear in the order array
312 /// The first dimension in order is the fastest-changing
313 ///
314 /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
315 ///
316 /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
317 /// result=[?,?,?]
318 ///
319 /// i=0 (process columns, dimIdx=2, dimSize=4):
320 /// result[2] = 22 % 4 = 2 (column coordinate)
321 /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
322 ///
323 /// i=1 (process rows, dimIdx=1, dimSize=4):
324 /// result[1] = 5 % 4 = 1 (row coordinate)
325 /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
326 ///
327 /// i=2 (process layers, dimIdx=0, dimSize=2):
328 /// result[0] = 1 % 2 = 1 (layer coordinate)
329 /// (no remaining update - last iteration)
330 ///
331 /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
332 for (size_t i = 0; i < order.size(); ++i) {
333 int64_t dimIdx = order[i];
334 int64_t dimSize = sgLayoutInt[dimIdx];
335
336 Value dimSizeVal =
337 builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
338
339 /// Extract the coordinate for this dimension using modulo operation
340 /// This gives us "how far within this dimension" we are
341 /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
342 /// this dimension)
343 result[dimIdx] =
344 builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
345
346 /// Update remaining for the next dimension by removing what we've already
347 /// processed. Division tells us "how many complete groups of this dimension
348 /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
349 /// completed 5 groups of 4) Skip this for the last iteration since there's
350 /// no next dimension to process
351 if (i < order.size() - 1) {
352 remaining =
353 builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
354 }
355 }
356 return result;
357}
358
359/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
360/// instructions for computing multi-dimensional offsets when distributed by
361/// LayoutAttr.
362FailureOr<SmallVector<SmallVector<Value>>>
363LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
364 Value linearId, ArrayRef<int64_t> shape) {
365 SmallVector<int64_t> layout;
366 SmallVector<int64_t> subShape;
367 if (isForWorkgroup()) {
368 layout = getEffectiveSgLayoutAsInt();
369 subShape = getEffectiveSgDataAsInt();
370 } else if (isForSubgroup()) {
371 layout = getEffectiveLaneLayoutAsInt();
372 subShape = getEffectiveLaneDataAsInt();
373 } else {
374 return failure();
375 }
376 if (subShape.empty()) {
377 if (auto derivedShape = computeShapeRatio(shape, layout))
378 subShape = derivedShape.value();
379 else
380 return failure();
381 }
382
383 // delinearize Ids
384 auto maybeIds = delinearizeId(builder, loc, linearId);
385 if (failed(maybeIds))
386 return failure();
387 SmallVector<Value> ids = *maybeIds;
388
389 return genCoordinates(builder, loc, ids, layout, subShape, shape);
390}
391
392bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
393 if (dyn_cast<xegpu::SliceAttr>(other))
394 return false;
395
396 return *this == dyn_cast<xegpu::LayoutAttr>(other);
397}
398
399// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
400DistributeLayoutAttr
401LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
402 auto sgDataOpt = getSgData();
403 auto instDataOpt = getInstData();
404 auto laneDataOpt = getLaneData();
405
406 SmallVector<int32_t> sgData;
407 SmallVector<int32_t> instData;
408 SmallVector<int32_t> laneData;
409
410 if (sgDataOpt)
411 sgData = llvm::to_vector(sgDataOpt.asArrayRef());
412
413 if (instDataOpt)
414 instData = llvm::to_vector(instDataOpt.asArrayRef());
415
416 if (laneDataOpt)
417 laneData = llvm::to_vector(laneDataOpt.asArrayRef());
418
419 for (auto dim : unitDims) {
420 if (dim < static_cast<int64_t>(sgData.size()))
421 sgData[dim] = 1;
422 if (dim < static_cast<int64_t>(instData.size()))
423 instData[dim] = 1;
424 if (dim < static_cast<int64_t>(laneData.size()))
425 laneData[dim] = 1;
426 }
427
428 return LayoutAttr::get(
429 getContext(), getSgLayout(),
430      sgData.empty() ? DenseI32ArrayAttr()
431                     : DenseI32ArrayAttr::get(getContext(), sgData),
432 instData.empty() ? DenseI32ArrayAttr()
433 : DenseI32ArrayAttr::get(getContext(), instData),
434 getLaneLayout(),
435 laneData.empty() ? DenseI32ArrayAttr()
436 : DenseI32ArrayAttr::get(getContext(), laneData),
437 getOrder());
438}
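// For illustration (a sketch): with sg_data = [16, 32], inst_data = [8, 16]
// and lane_data = [1, 2], setUnitDimData({1}) yields sg_data = [16, 1],
// inst_data = [8, 1] and lane_data = [1, 1]; sg_layout and lane_layout are
// left unchanged.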
439
440// Set sg_layout and lane_layout to 1 for the specified unit dims.
441DistributeLayoutAttr
442LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
443 auto sgLayoutOpt = getSgLayout();
444 auto laneLayoutOpt = getLaneLayout();
445
446 SmallVector<int32_t> sgLayout;
447 SmallVector<int32_t> laneLayout;
448
449 if (sgLayoutOpt)
450 sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
451 if (laneLayoutOpt)
452 laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
453
454 for (auto dim : unitDims) {
455 if (dim < static_cast<int64_t>(sgLayout.size()))
456 sgLayout[dim] = 1;
457 if (dim < static_cast<int64_t>(laneLayout.size()))
458 laneLayout[dim] = 1;
459 }
460
461 return LayoutAttr::get(
462 getContext(),
463 sgLayout.empty() ? DenseI32ArrayAttr()
464 : DenseI32ArrayAttr::get(getContext(), sgLayout),
465 getSgData(), getInstData(),
466 laneLayout.empty() ? DenseI32ArrayAttr()
467 : DenseI32ArrayAttr::get(getContext(), laneLayout),
468 getLaneData(), getOrder());
469}
470
471// Derive a new layout with sg_data, inst_data and lane_data set to the
472// specified values for the given dimension
473DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
474 int64_t instData,
475 int64_t laneData) {
476
477 SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
478 SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
479 SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
480
481 if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
482 sgDataVec[dim] = sgData;
483 if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
484 instDataVec[dim] = instData;
485 if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
486 laneDataVec[dim] = laneData;
487
488 SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
489 SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
490 SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
491
492 return LayoutAttr::get(
493 getContext(), getSgLayout(),
494 sgDataVec.empty() ? DenseI32ArrayAttr()
495 : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
496 instDataVec.empty() ? DenseI32ArrayAttr()
497 : DenseI32ArrayAttr::get(getContext(), instDataVec32),
498 getLaneLayout(),
499 laneDataVec.empty() ? DenseI32ArrayAttr()
500 : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
501 getOrder());
502}
503
504// Derive a new layout by collapsing dimensions.
505// `dimGroup` specifies a group of adjacent dimensions
506// that are collapsed into a single dimension in the derived layout.
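// For illustration (a sketch): for a layout with sg_layout = [2, 4, 4] and
// sg_data = [1, 16, 32], collapseDims({1, 2}) produces sg_layout = [2, 16]
// and sg_data = [1, 512].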
507DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
508
509 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
510 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
511 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
512 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
513 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
514
515 DenseI32ArrayAttr orderAttr = getOrder();
516 SmallVector<int32_t> orderVec;
517 if (orderAttr && !orderAttr.empty()) {
518 orderVec = llvm::to_vector(
519 llvm::map_range(orderAttr.asArrayRef(),
520 [](int32_t idx) { return static_cast<int32_t>(idx); }));
521 }
522
523 SmallVector<int64_t> sortedDimGroup = dimGroup;
524 llvm::sort(sortedDimGroup);
525 int64_t dimBeforeCurrent = -1;
526 for (auto dimIdx : sortedDimGroup) {
527    // When order is present, adjacency is defined on the order values (e.g.
528    // [3, 2, 1, 0], decreasing); otherwise it is based on the dim indices
529    // (e.g. [0, 1, 2, 3], increasing).
530 if (dimBeforeCurrent >= 0) {
531 if (!orderVec.empty()) {
532 int64_t orderBefore = orderVec[dimBeforeCurrent];
533 int64_t orderCurrent = orderVec[dimIdx];
534 if (orderBefore != (orderCurrent - 1))
535 llvm::report_fatal_error(
536 "dimensions being collapsed must be adjacent in order");
537 } else {
538 if (dimIdx != (dimBeforeCurrent + 1))
539 llvm::report_fatal_error(
540 "dimensions being collapsed must be adjacent");
541 }
542 }
543 dimBeforeCurrent = dimIdx;
544 }
545
546 int firstDim = sortedDimGroup.front();
547
548 // collapse the dimensions in dimGroup into one dimension by multiplying their
549 // sizes together
550
551 if (!sgLayout.empty()) {
552 int64_t collapsedSglayout = 1, collapsedSgData = 1;
553 for (auto dimIdx : dimGroup) {
554 collapsedSglayout *= sgLayout[dimIdx];
555 collapsedSgData *= sgData[dimIdx];
556 }
557 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
558 sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
559 sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
560 }
561 sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
562 sgData.insert(sgData.begin() + firstDim, collapsedSgData);
563 }
564
565 if (!instData.empty()) {
566 int64_t collapsedInstData = 1;
567 for (auto dimIdx : dimGroup)
568 collapsedInstData *= instData[dimIdx];
569 for (auto dimIdx : llvm::reverse(sortedDimGroup))
570 instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
571 instData.insert(instData.begin() + firstDim, collapsedInstData);
572 }
573
574 if (!laneLayout.empty()) {
575 int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
576 for (auto dimIdx : dimGroup) {
577 collapsedLaneLayout *= laneLayout[dimIdx];
578 collapsedLaneData *= laneData[dimIdx];
579 }
580 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
581 laneLayout.erase(laneLayout.begin() + dimIdx,
582 laneLayout.begin() + dimIdx + 1);
583 laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
584 }
585 laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
586 laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
587 }
588
589  // Go through the values inside collapsedOrder and re-map them to the range
590  // [0, N-1], where N is the number of dimensions in the collapsed shape.
591  // For example, collapsing dim group {2, 3} of order [1, 2, 3, 4] gives the
592  // intermediate order [1, 3, 4]; the code below remaps it to [1, 2, 3].
593 SmallVector<int32_t> collapsedOrder;
594 if (!orderVec.empty()) {
595
596 for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
597 if (dimIdx != firstDim)
598 orderVec.erase(orderVec.begin() + dimIdx,
599 orderVec.begin() + dimIdx + 1);
600 }
601
602 // say we have orderVec = {5, 3, 2, 1, 0}
603 // Create indices [0, 1, 2, 3, 4]
604 SmallVector<size_t> indices =
605 llvm::to_vector(llvm::seq<size_t>(0, orderVec.size()));
606
607 // Sort indices based on corresponding values
608 llvm::sort(indices,
609 [&](size_t a, size_t b) { return orderVec[a] < orderVec[b]; });
610 collapsedOrder = llvm::to_vector(llvm::map_range(
611 indices, [&](size_t i) { return static_cast<int32_t>(i); }));
612 }
613
614 // Create collapsed layout
615 SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
616 SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
617 SmallVector<int32_t> instData32(instData.begin(), instData.end());
618 SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
619 SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
620
621 auto collapsedLayout = xegpu::LayoutAttr::get(
622 getContext(),
623 sgLayout32.empty() ? DenseI32ArrayAttr()
624 : DenseI32ArrayAttr::get(getContext(), sgLayout32),
625 sgData32.empty() ? DenseI32ArrayAttr()
626 : DenseI32ArrayAttr::get(getContext(), sgData32),
627 instData32.empty() ? DenseI32ArrayAttr()
628 : DenseI32ArrayAttr::get(getContext(), instData32),
629 laneLayout32.empty() ? DenseI32ArrayAttr()
630 : DenseI32ArrayAttr::get(getContext(), laneLayout32),
631 laneData32.empty() ? DenseI32ArrayAttr()
632 : DenseI32ArrayAttr::get(getContext(), laneData32),
633      collapsedOrder.empty()
634          ? DenseI32ArrayAttr()
635 : DenseI32ArrayAttr::get(getContext(), collapsedOrder));
636 return collapsedLayout;
637}
638
639//===----------------------------------------------------------------------===//
640// XeGPU_SliceAttr
641//===----------------------------------------------------------------------===//
642LogicalResult
643SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
644 xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
645
646 if (!dims)
647 return emitError() << "expected dims attribute";
648
649 // check every element in dims is unique and smaller than rank
650 llvm::SmallDenseSet<int64_t> seen;
651 for (int64_t dim : dims.asArrayRef()) {
652 if (dim < 0)
653 return emitError() << "invalid dim (" << dim << ") in slice attribute.";
654 if (!seen.insert(dim).second)
655 return emitError() << "repeated dim (" << dim << ") in slice attribute.";
656 }
657 return success();
658}
659
660SliceAttr SliceAttr::flatten() const {
661 xegpu::DistributeLayoutAttr parent = getParent();
662 SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
663
664 while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
665 parent = sliceAttr.getParent();
666 slicedDims.push_back(sliceAttr.getDims());
667 }
668
669 auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
670 SmallVector<int64_t> indices =
671 llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
672
673  // Get the remaining dims (flattened) by applying slice ops with all slicedDims.
674 SmallVector<int64_t> remainingDims(indices);
675 for (auto dim : llvm::reverse(slicedDims))
676 remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
677 dim.asArrayRef());
678
679  // Get the flattened sliced dims by applying slice ops with the remaining dims.
680 SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
681 llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
682
683 return xegpu::SliceAttr::get(
684 getContext(), layoutAttr,
685 DenseI64ArrayAttr::get(getContext(), flattendDims));
686}
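// For illustration (a sketch): with a rank-3 LayoutAttr L, slicing L over
// dims = [0] and then slicing the result over dims = [1] flattens to a
// single slice of L over dims = [0, 2]: the only remaining parent dim is 1,
// so the flattened sliced dims are [0, 2].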
687
688FailureOr<SmallVector<Value>>
689SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
690 SliceAttr attr = flatten();
691 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
692 return parent.delinearizeId(builder, loc, linearId);
693}
694
695// Implements DistributeLayoutAttr::computeDistributedCoords to generate
696// instructions for computing multi-dimensional offsets when distributed by
697// SliceAttr.
698FailureOr<SmallVector<SmallVector<Value>>>
699SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
700 Value linearId, ArrayRef<int64_t> shape) {
701 assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
702 if (!isForWorkgroup())
703 return failure();
704
705 SmallVector<int64_t> layout;
706 SmallVector<int64_t> subShape;
707 if (isForWorkgroup()) {
708 layout = getEffectiveSgLayoutAsInt();
709 subShape = getEffectiveSgDataAsInt();
710 } else if (isForSubgroup()) {
711 layout = getEffectiveLaneLayoutAsInt();
712 subShape = getEffectiveLaneDataAsInt();
713 } else {
714 return failure();
715 }
716
717 if (subShape.empty()) {
718 if (auto derivedShape = computeShapeRatio(shape, layout))
719 subShape = derivedShape.value();
720 else
721 return failure();
722 }
723
724 // delinearize Ids
725 auto maybeIds = delinearizeId(builder, loc, linearId);
726 if (failed(maybeIds))
727 return failure();
728
729 // The effective sgIds for offsets computing correspond
730 // to the dims that are not sliced.
731 ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
732 SmallVector<Value> sgIds =
733 XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
734
735 return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
736}
737
738bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
739 auto flattenedThis = flatten();
740 // If other is a LayoutAttr, just compare directly with parent of
741 // flattenedThis.
742 if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
743 return flattenedThis.getParent() == otherLayout;
744 // If other is a SliceAttr, flatten it first before comparing.
745 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
746 // Both must have common parent LayoutAttr.
747 if (flattenedThis.getParent() != flattenedOther.getParent())
748 return false;
749  // flattenedOther's sliced dims must be a subset of flattenedThis's sliced
750 // dims.
751 llvm::SmallDenseSet<int64_t> thisDims(
752 flattenedThis.getDims().asArrayRef().begin(),
753 flattenedThis.getDims().asArrayRef().end());
754 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
755 [&](int64_t dim) { return thisDims.contains(dim); });
756}
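// For illustration (a sketch): if `this` flattens to a slice of L over
// dims = [0, 2] and `other` is a slice of the same L over dims = [2], then
// other's sliced dims {2} are a subset of {0, 2} and isSliceOf returns true.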
757
758bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
759 if (dyn_cast<xegpu::LayoutAttr>(other))
760 return false;
761
762 auto flattenedThis = flatten();
763 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
764
765 return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
766 (flattenedThis.getDims() == flattenedOther.getDims()));
767}
768
769xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
770 if (sliceDimsToDrop.empty())
771 return *this;
772 SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
773 for (auto dim : sliceDimsToDrop) {
774 auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
775 assert(foundIt != sliceDims.end() &&
776 "Expected to find the specified reduction dim in slice dims");
777 sliceDims.erase(foundIt);
778 }
779
780 auto sliceWithoutDims = xegpu::SliceAttr::get(
781 this->getContext(), getParent(),
782 DenseI64ArrayAttr::get(this->getContext(), sliceDims));
783
784 return sliceWithoutDims;
785}
786
787// Helper function to adjust dimensions from sliced space to parent space
788// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
789// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
790// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
791// it maps to dim [2] in parent space.
792static SmallVector<int64_t>
793mapSlicedDimsToParentSpace(const SmallVector<int64_t> &dimsToMap,
794                           ArrayRef<int64_t> sliceDims) {
795 // Rather than recovering the exact parent rank, we compute a safe upper
796 // bound so that dimsToMap can be adjusted safely. This upper bound is
797 // defined as max(dimsToMap, sliceDims) + 1 + sliceDims.size().
798 int64_t maxDim = -1;
799 maxDim =
800 std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
801 maxDim =
802 std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
803 int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
804
805 // get remaining dims in parent space after applying slicing with parent's
806 // slice Dims
807 llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
808 sliceDims.end());
809 SmallVector<int64_t> remainingDims;
810 for (int64_t i = 0; i < parentSpaceRank; ++i) {
811 if (!slicedDimsSet.contains(i))
812 remainingDims.push_back(i);
813 }
814
815 // Map unit dims from sliced space to parent space
816 SmallVector<int64_t> adjustUnitDims;
817 for (auto dim : dimsToMap) {
818 int64_t mappedDim = remainingDims[dim];
819 adjustUnitDims.push_back(mappedDim);
820 }
821
822 return adjustUnitDims;
823}
824
825// Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
826DistributeLayoutAttr
827SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
828 DistributeLayoutAttr parentLayout = getParent();
829
830 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
831
832 SmallVector<int64_t> adjustUnitDims =
833 mapSlicedDimsToParentSpace(unitDims, sliceDims);
834
835 return SliceAttr::get(getContext(),
836 parentLayout.setUnitDimData(adjustUnitDims), getDims());
837}
838
839// Set sg_layout and lane_layout to 1 for the specified unit dims.
840DistributeLayoutAttr
841SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
842 DistributeLayoutAttr parentLayout = getParent();
843
844 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
845
846 SmallVector<int64_t> adjustUnitDims =
847 mapSlicedDimsToParentSpace(unitDims, sliceDims);
848
849 return SliceAttr::get(
850 getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
851}
852
853// Derive a new layout with sg_data, inst_data and lane_data set to the
854// specified values for the given dimension
855DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
856 int64_t instData, int64_t laneData) {
857 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
858 auto parent = getParent();
859
860 SmallVector<int64_t> dimSet;
861 dimSet.push_back(dim);
862 SmallVector<int64_t> adjustDims =
863 mapSlicedDimsToParentSpace(dimSet, sliceDims);
864 return SliceAttr::get(
865 getContext(),
866 parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
867}
868
869// Derive a new layout by collapsing dimensions.
870// `dimGroup` specifies a group of adjacent dimensions
871// that are collapsed into a single dimension in the derived layout.
872DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
873
874 // Map the sliced dims from parent space to collapsed space
875 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
876
877 SmallVector<int64_t> dimsInParentSpace =
878 mapSlicedDimsToParentSpace(dimGroup, sliceDims);
879
880 auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
881
882 return SliceAttr::get(getContext(), collapsedParent,
883 DenseI64ArrayAttr::get(getContext(), sliceDims));
884}
885
886//===----------------------------------------------------------------------===//
887// XeGPU_RangeAttr
888//===----------------------------------------------------------------------===//
889
890LogicalResult
891RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
892 IntegerAttr startOfRange, IntegerAttr endOfRange) {
893 if (startOfRange.getInt() >= endOfRange.getInt())
894 return emitError() << "'end' : " << endOfRange.getInt()
895 << " must be greater than 'start' : "
896 << startOfRange.getInt();
897
898 return success();
899}
900
901//===----------------------------------------------------------------------===//
902// XeGPU_TensorDescType
903//===----------------------------------------------------------------------===//
904
905mlir::Type TensorDescType::parse(AsmParser &parser) {
906 llvm::SmallVector<int64_t> shape;
907 mlir::Type elementType;
908 mlir::FailureOr<mlir::Attribute> encoding;
909 mlir::FailureOr<mlir::Attribute> layout;
910
911 // Parse literal '<'
912 if (parser.parseLess())
913 return {};
914
915 auto shapeLoc = parser.getCurrentLocation();
916 if (mlir::failed(parser.parseDimensionList(shape))) {
917 parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
918 return {};
919 }
920
921 auto elemTypeLoc = parser.getCurrentLocation();
922 if (mlir::failed(parser.parseType(elementType))) {
923 parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
924 return {};
925 }
926
927 // parse optional attributes
928 while (mlir::succeeded(parser.parseOptionalComma())) {
929 mlir::Attribute attr;
930 ParseResult res = parser.parseAttribute(attr);
931 if (mlir::succeeded(res)) {
932 if (mlir::isa<LayoutAttr>(attr)) {
933 layout = attr;
934 continue;
935 }
936 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
937 encoding = attr;
938 continue;
939 }
940 }
941 return {};
942 }
943
944 // Parse literal '>'
945 if (parser.parseGreater())
946 return {};
947
948 MLIRContext *ctxt = parser.getContext();
949 return TensorDescType::getChecked(
950 [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
951 elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
952 layout.value_or(mlir::Attribute()));
953}
954
955void TensorDescType::print(AsmPrinter &printer) const {
956 printer << "<";
957
958 auto shape = getShape();
959 for (int64_t dim : shape) {
960 if (mlir::ShapedType::isDynamic(dim))
961 printer << '?';
962 else
963 printer << dim;
964 printer << 'x';
965 }
966
967 printer << getElementType();
968
969 auto encoding = getEncoding();
970 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
971 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
972 printer << ", " << encoding;
973
974 if (auto layout = getLayout())
975 printer << ", " << layout;
976
977 printer << ">";
978}
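// For illustration, the printed form looks roughly like the following
// (assumed assembly examples):
//   !xegpu.tensor_desc<8x16xf16>
//   !xegpu.tensor_desc<128x256xf32,
//       #xegpu.layout<sg_layout = [4, 2], sg_data = [16, 32]>>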
979
980TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
981 mlir::Type elementType, int array_length,
982 bool boundary_check,
983 MemorySpace memory_space,
984 mlir::Attribute layout) {
985 auto context = elementType.getContext();
986 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
987 boundary_check);
988 return Base::get(context, shape, elementType, attr, layout);
989}
990
991TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
992 mlir::Type elementType, int chunk_size,
993 MemorySpace memory_space,
994 mlir::Attribute layout) {
995 auto context = elementType.getContext();
996 auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
997 return Base::get(context, shape, elementType, attr, layout);
998}
999
1000LogicalResult
1001TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
1002 llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
1003 mlir::Attribute encoding, mlir::Attribute layout) {
1004 size_t rank = shape.size();
1005
1006 if (rank == 0)
1007 return emitError() << "expected non-zero rank tensor";
1008
1009 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1010 if (blockAttr) {
1011 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
1012 if (rank > 1 && memorySpaceAttr &&
1013 memorySpaceAttr.getValue() == MemorySpace::SLM)
1014 return emitError() << "SLM is only supported for 1D block tensor";
1015 }
1016
1017 if (!elementType.isIntOrFloat())
1018 return emitError() << "unsupported element type " << elementType
1019 << ": expected integer or float";
1020
1021  // For gather and scatter ops, low-precision types are packed in 32-bit
1022  // units.
1023  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
1024  int chunkAlignmentFactor =
1025      bitWidth < uArch::generalPackedFormatBitSize
1026          ? uArch::generalPackedFormatBitSize / bitWidth
1027          : 1;
1028 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
1029 if (scatterAttr) {
1030 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
1031 if (rank == 1 && chunkSize != 1)
1032 return emitError() << "expected non-contiguous elements for 1D tensor";
1033
1034 // If chunk size > 1, the second dimension of the tensor shape must be
1035 // equal to chunk size and it must be a multiple of the
1036 // chunkAlignmentFactor.
1037 if (chunkSize > 1) {
1038 if (shape.back() != chunkSize)
1039 return emitError() << "expected last dim of tensor to match chunk size";
1040 if (shape.back() % chunkAlignmentFactor != 0)
1041 return emitError() << "expected last dim of tensor to be a multiple of "
1042 << chunkAlignmentFactor;
1043 }
1044 }
1045
1046 auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
1047 if (layoutAttr) {
1048 if (rank != (size_t)layoutAttr.getRank())
1049 return emitError() << "expected layout rank to match tensor rank";
1050
1051 auto laneData = layoutAttr.getLaneData();
1052 if (scatterAttr && laneData) {
1053 // Validate subgroup mapping rules for scattered tensors.
1054 // if chunkSize > 1, the last dimension of the tensor should
1055 // be distributed in the units divisible by chunkAlignmentFactor.
1056 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
1057 if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
1058 return emitError()
1059 << "expected last dim of lane_data to be a multiple of: "
1060 << chunkAlignmentFactor;
1061 }
1062
1063 if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
1064 std::string shapeStr;
1065 llvm::raw_string_ostream stream(shapeStr);
1066 llvm::interleaveComma(shape, stream);
1067 return emitError() << "cannot distribute [" << shapeStr << "] using "
1068 << layoutAttr;
1069 }
1070 }
1071 return success();
1072}
1073
1074//===----------------------------------------------------------------------===//
1075// XeGPU_MemDescType
1076//===----------------------------------------------------------------------===//
1077mlir::Type MemDescType::parse(AsmParser &parser) {
1078 llvm::SmallVector<int64_t> shape;
1079 mlir::Type elementType;
1080 mlir::FailureOr<MemLayoutAttr> layout;
1081
1082 // Parse literal '<'
1083 if (parser.parseLess())
1084 return {};
1085
1086 auto shapeLoc = parser.getCurrentLocation();
1087 if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
1088 parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
1089 return {};
1090 }
1091
1092 auto elemTypeLoc = parser.getCurrentLocation();
1093 if (mlir::failed(parser.parseType(elementType))) {
1094 parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
1095 return {};
1096 }
1097
1098 // parse optional attributes
1099 if (mlir::succeeded(parser.parseOptionalComma())) {
1100 MemLayoutAttr attr;
1101 ParseResult res = parser.parseAttribute(attr);
1102 if (mlir::failed(res))
1103 return {};
1104 layout = attr;
1105 }
1106
1107 // Parse literal '>'
1108 if (parser.parseGreater())
1109 return {};
1110
1111 MLIRContext *ctxt = parser.getContext();
1112 return MemDescType::getChecked(
1113 [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
1114 elementType, layout.value_or(MemLayoutAttr()));
1115}
1116
1117void MemDescType::print(AsmPrinter &printer) const {
1118 printer << "<";
1119
1120 printer.printDimensionList(getShape());
1121 printer << 'x';
1122 printer << getElementType();
1123
1124 if (auto layout = getMemLayout())
1125 printer << ", " << layout;
1126
1127 printer << ">";
1128}
1129
1130//===----------------------------------------------------------------------===//
1131// XeGPU_MemLayoutAttr
1132//===----------------------------------------------------------------------===//
1133
1134Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
1135
1136 auto context = parser.getContext();
1137 llvm::SMLoc loc = parser.getCurrentLocation();
1138
1139 llvm::SmallDenseSet<StringRef> seenKeys;
1140 SmallVector<NamedAttribute> attributes;
1141
1142 auto parseElt = [&]() -> ParseResult {
1143 StringRef nameId;
1144 if (failed(parser.parseKeyword(&nameId)))
1145 return parser.emitError(loc, "expected valid attribute name");
1146
1147 if (!seenKeys.insert(nameId).second)
1148 return parser.emitError(loc, "duplicate key '")
1149             << nameId << "' in mem layout attribute";
1150
1151 if (failed(parser.parseEqual()))
1152 return failure();
1153
1154 Attribute attr;
1155 if (failed(parser.parseAttribute(attr)))
1156 return failure();
1157 attributes.emplace_back(nameId, attr);
1158 return success();
1159 };
1160
1161 // Parse literal '<'
1162 if (parser.parseLess())
1163 return {};
1164
1165 if (failed(parser.parseCommaSeparatedList(parseElt)))
1166 return {};
1167
1168 // Parse literal '>'
1169 if (parser.parseGreater())
1170 return {};
1171
1172 return parser.getChecked<MemLayoutAttr>(
1173 loc, context, DictionaryAttr::get(context, attributes));
1174}
1175
1176void MemLayoutAttr::print(AsmPrinter &printer) const {
1177 printer << "<";
1178 ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
1179 for (size_t i = 0; i < attrs.size(); i++) {
1180 printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
1181 if (i < attrs.size() - 1)
1182 printer << ", ";
1183 }
1184 printer << ">";
1185}
1186// a helper utility to perform binary operation on OpFoldResult.
1187// If both a and b are attributes, it will simply return the result.
1188// Otherwise, the corresponding arith op will be generated, and a
1189// constant op will be created if one of them is an attribute.
1190template <typename ArithOp>
1191OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
1192                      OpBuilder &builder) {
1193 auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
1194 auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
1195 return ArithOp::create(builder, loc, aVal, bVal).getResult();
1196}
1197
1198// a helper utility to perform division operation on OpFoldResult and int64_t.
1199#define div(a, b) \
1200 genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
1201
1203// a helper utility to perform remainder operation on OpFoldResult and int64_t.
1203#define rem(a, b) \
1204 genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
1205
1206// a helper utility to perform multiply operation on OpFoldResult and int64_t.
1207#define mul(a, b) \
1208 genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
1209
1210// a helper utility to perform addition operation on two OpFoldResult.
1211#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
1212
1213// block the given offsets according to the block shape
1214// say the original offset is [y, x], and the block shape is [By, Bx],
1215// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
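// For illustration (a sketch): offsets = [10, 70] with blockShape = [8, 64]
// give divs = [1, 1] and rems = [2, 6], i.e. blocked offsets [1, 1, 2, 6].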
1216SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
1217                                            ArrayRef<OpFoldResult> offsets,
1218 ArrayRef<int64_t> blockShape) {
1219
1220 assert(offsets.size() == blockShape.size() &&
1221 "offsets and blockShape must have the same size");
1222 SmallVector<OpFoldResult> blockedOffsets;
1223 SmallVector<OpFoldResult> divs, rems;
1224
1225 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1226 divs.push_back(div(offset, block));
1227 rems.push_back(rem(offset, block));
1228 }
1229 blockedOffsets.append(divs.begin(), divs.end());
1230 blockedOffsets.append(rems.begin(), rems.end());
1231
1232 return blockedOffsets;
1233}
1234
1235// Get the strides as a vector of integers for MemDesc.
1236SmallVector<int64_t> MemDescType::getStrideShape() {
1237
1238 SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
1239
1240 ArrayAttr strideAttr = getStrideAttr();
1241 SmallVector<int64_t> strides;
1242 for (Attribute attr : strideAttr.getValue()) {
1243 strides.push_back(cast<IntegerAttr>(attr).getInt());
1244 }
1245
1246 SmallVector<int64_t> innerBlkShape = getBlockShape();
1247
1248  // Get perm ordering the dims from fastest-changing to slowest-changing:
1249  // perm[i] = the dim with the i-th smallest stride.
1250 SmallVector<int, 4> perm =
1251 llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
1252 llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
1253
1254 assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
1255
1256 SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
1257 innerBlkStride[perm[0]] = 1;
1258 for (size_t i = 1; i < perm.size(); ++i)
1259 innerBlkStride[perm[i]] =
1260 innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
1261
1262 // compute the original matrix shape using the stride info
1263 // and compute the number of blocks in each dimension
1264 // The shape of highest dim can't be derived from stride info,
1265 // but doesn't impact the stride computation for blocked layout.
1266 SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
1267 SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
1268 for (size_t i = 0; i < perm.size() - 1; ++i) {
1269 matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
1270 BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
1271 }
1272
1273 int64_t innerBlkSize = 1;
1274 for (auto s : innerBlkShape)
1275 innerBlkSize *= s;
1276
1277 SmallVector<int64_t> outerBlkStride(matrixShape.size());
1278 outerBlkStride[perm[0]] = innerBlkSize;
1279 for (size_t i = 0; i < perm.size() - 1; ++i) {
1280 outerBlkStride[perm[i + 1]] =
1281 outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
1282 }
1283
1284 // combine the inner and outer strides
1285 SmallVector<int64_t> blockedStrides;
1286 blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
1287 blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
1288
1289 return blockedStrides;
1290}
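// For illustration (a sketch): a 32x64 MemDesc with row-major strides [64, 1]
// and block shape [8, 16] yields perm = [1, 0], inner block strides [16, 1],
// 4 blocks of 128 elements along the inner dim, and thus blocked strides
// [512, 128, 16, 1] for the blocked indices [yBlk, xBlk, yIn, xIn].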
1291
1292// Calculate the linear offset using the blocked offsets and stride
1293Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
1294 ArrayRef<OpFoldResult> offsets) {
1295
1296 SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
1297 SmallVector<int64_t> blockShape = getBlockShape();
1298 SmallVector<int64_t> strides = getStrideShape();
1299 SmallVector<OpFoldResult> blockedOffsets;
1300
1301  // blockShape equal to matrixShape means no blocking
1302 if (llvm::equal(blockShape, matrixShape)) {
1303 // remove the outer dims from strides
1304 strides.erase(strides.begin(), strides.begin() + matrixShape.size());
1305 } else {
1306 assert(offsets.size() == blockShape.size() &&
1307 "offsets and blockShape must have the same size");
1308 // say the original offset is [y, x], and the block shape is [By, Bx],
1309 // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1310
1311 SmallVector<OpFoldResult> divs, rems;
1312
1313 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1314 divs.push_back(div(offset, block));
1315 rems.push_back(rem(offset, block));
1316 }
1317 blockedOffsets.append(divs.begin(), divs.end());
1318 blockedOffsets.append(rems.begin(), rems.end());
1319 offsets = blockedOffsets;
1320 }
1321
1322 // Start with initial value as matrix descriptor's base offset.
1323 Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
1324 for (size_t i = 0; i < offsets.size(); ++i) {
1325 OpFoldResult mulResult = mul(offsets[i], strides[i]);
1326 Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
1327 linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
1328 }
1329
1330 return linearOffset;
1331}
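// For illustration (continuing the sketch above): offsets [10, 35] in a
// 32x64 MemDesc with block shape [8, 16] are blocked to [1, 2, 2, 3]; with
// blocked strides [512, 128, 16, 1] the linear offset is
// 1*512 + 2*128 + 2*16 + 3 = 803.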
1332
1333} // namespace xegpu
1334} // namespace mlir
1335
1336#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
1337#define GET_ATTRDEF_CLASSES
1338#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
1339#define GET_TYPEDEF_CLASSES
1340#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>