//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using std::optional;

namespace mlir {
namespace xegpu {

void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}

// A `srcShape` consists of N distribution units, each of shape
// `subShapesLayout` x `subShape`. A `delinearizedId` identifies a particular
// `subShape` within each distribution unit.
// Example:
//   WG data is 128x256. SG data is 16x32 in a 4x2 layout. This gives a
//   distribution unit of shape 64x64, and there are 2x4 such distribution
//   units. `delinearizedId` identifies the 16x32 tile of a subgroup within
//   each distribution unit.
static SmallVector<SmallVector<Value>>
genCoordinates(OpBuilder &builder, Location loc,
               SmallVector<Value> delinearizedId,
               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
               ArrayRef<int64_t> srcShape) {
  SmallVector<SmallVector<Value>> coordinates;

  // A distribution unit must be less than or equal to `srcShape`.
  SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
      llvm::zip_equal(srcShape,
                      computeElementwiseMul(subShapesLayout, subShape)),
      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

  // Get the offset of `subShape` within a distribution unit.
  SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
        return builder.createOrFold<index::MulOp>(
            loc, std::get<0>(t),
            builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
      });

  // For each distribution unit:
  for (SmallVector<int64_t> unitOffs :
       StaticTileOffsetRange(srcShape, distUnitShape)) {
    // Get the distribution unit offset within `srcShape`.
    SmallVector<Value> base =
        llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
          return arith::ConstantIndexOp::create(builder, loc, d);
        });
    // Calculate the `subShape` offset within `srcShape`.
    SmallVector<Value> adds =
        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
                            [&](const auto &t) -> Value {
                              return builder.createOrFold<arith::AddIOp>(
                                  loc, std::get<0>(t), std::get<1>(t));
                            });
    // Do not go beyond `srcShape` bounds.
    SmallVector<Value> mods = llvm::map_to_vector(
        llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
          return builder.createOrFold<index::RemUOp>(
              loc, std::get<0>(t),
              arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
        });

    coordinates.push_back(mods);
  }
  return coordinates;
}
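
// Worked instance of the example above (illustrative, not from the original
// source): srcShape = [128, 256], subShapesLayout = [4, 2],
// subShape = [16, 32] gives distUnitShape = min([128, 256], [64, 64]) =
// [64, 64], i.e. 2x4 distribution units. For delinearizedId = [1, 0],
// distUnitLocalOffset = [1*16, 0*32] = [16, 0], so the returned coordinates
// are [16, 0], [16, 64], [16, 128], [16, 192], [80, 0], ..., [80, 192]:
// one 16x32 tile per distribution unit, folded into bounds by the RemU.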

// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
                                         xegpu::DistributeLayoutAttr attr) {
  assert(attr && "Layout attribute is missing.");

  // Checks whether the given shape can be evenly distributed using the
  // specified layout and data attributes. If successful, it returns the work
  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
  // size per compute unit is calculated as follows:
  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
  //   - If `data` is not null: newShape[i] = data[i]
  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
  // share the data.
  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
                           llvm::ArrayRef<int64_t> layout,
                           llvm::ArrayRef<int64_t> data,
                           bool rr = true) -> optional<SmallVector<int64_t>> {
    SmallVector<int64_t> newShape(shape.begin(), shape.end());
    if (layout.size()) {
      if (layout.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(shape, layout);
      if (ratio.has_value()) {
        newShape = ratio.value();
      } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
        return std::nullopt;
      }
      // Round-robin case: continue with the original newShape.
    }

    if (data.size()) {
      if (data.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(newShape, data);
      if (!ratio.has_value() && rr)
        ratio = computeShapeRatio(data, newShape);
      if (!ratio.has_value())
        return std::nullopt;

      // If `data` is not null, always return it for the next phase.
      newShape = data;
    }
    return newShape;
  };

  // Check the sgLayout and sgData.
  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
                                    attr.getEffectiveSgDataAsInt());
  if (!maybeSgShape)
    return false;
  auto sgShape = maybeSgShape.value();

  // Check instData; it has neither a layout nor round-robin distribution.
  auto maybeInstShape =
      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
  if (!maybeInstShape)
    return false;
  auto instShape = maybeInstShape.value();

  // Check laneLayout and laneData.
  auto maybeLaneShape =
      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
                    attr.getEffectiveLaneDataAsInt(), false);
  return maybeLaneShape.has_value();
}
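
// Worked example (illustrative, not from the original source): for
// shape = [128, 256] with sg_layout = [4, 2] and sg_data = [16, 32], the
// sg phase computes the ratio [128, 256] / [4, 2] = [32, 128], verifies
// that [32, 128] is divisible by [16, 32] (ratio [2, 4]), and carries
// sg_data = [16, 32] into the inst/lane phases, which pass trivially when
// inst_data and lane_layout are absent. The shape is thus distributable.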

//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
                                             xegpu::MemorySpace memory_space,
                                             int array_length,
                                             bool boundary_check) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto lengthAttr =
      IntegerAttr::get(IntegerType::get(context, 64), array_length);
  auto boundaryAttr = BoolAttr::get(context, boundary_check);
  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}

bool BlockTensorDescAttr::hasDefaultsOnly() {
  return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
         getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
}

//===----------------------------------------------------------------------===//
// XeGPU_ScatterTensorDescAttr
//===----------------------------------------------------------------------===//
ScatterTensorDescAttr
ScatterTensorDescAttr::get(mlir::MLIRContext *context,
                           xegpu::MemorySpace memory_space, int chunk_size) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto chunkSizeAttr =
      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
  return Base::get(context, scopeAttr, chunkSizeAttr);
}

LogicalResult ScatterTensorDescAttr::verify(
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
  int64_t chunkSize = chunk_size.getInt();
  if (chunkSize <= 0)
    return emitError() << "invalid chunk size";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LayoutAttr
//===----------------------------------------------------------------------===//
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // A valid layout must include at least one of sg_layout and lane_layout.
  // sg_layout is essential for Workgroup layout, while lane_layout is
  // required for Subgroup layout.
  if (!sg_layout && !inst_data && !lane_layout) {
    return emitError()
           << "expected at least one of sg_layout, inst_data or lane_layout";
  }

  // Check that sg_layout, inst_data and lane_layout have the same rank
  // when present.
  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError() << "expected inst_data and lane_layout to have the same "
                          "rank, got inst_data "
                       << inst_data.size() << ", lane_layout "
                       << lane_layout.size();
  }

  // sg_data is optional for Workgroup layout, but its presence requires
  // sg_layout.
  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  // lane_data is optional for Subgroup layout, but its presence requires
  // lane_layout.
  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
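
// Illustrative attribute forms exercised by this verifier (assuming the
// standard #xegpu.layout syntax; not from the original source):
//   #xegpu.layout<sg_layout = [4, 2], sg_data = [16, 32]>      // ok
//   #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 2]>    // ok
//   #xegpu.layout<sg_data = [16, 32]>    // error: sg_data without sg_layout
//   #xegpu.layout<sg_layout = [4, 2], order = [0]>  // error: rank mismatch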

FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {

  SmallVector<int64_t> sgLayoutInt;
  if (isForWorkgroup()) {
    sgLayoutInt = getEffectiveSgLayoutAsInt();
  } else if (isForSubgroup()) {
    sgLayoutInt = getEffectiveLaneLayoutAsInt();
  } else {
    return failure();
  }

  DenseI32ArrayAttr orderAttr = getOrder();

  // Handle the order attribute.
  SmallVector<int64_t> order;
  if (orderAttr && !orderAttr.empty()) {
    order = llvm::to_vector(
        llvm::map_range(orderAttr.asArrayRef(),
                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
  } else {
    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
    order = llvm::to_vector(
        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
  }

  if (order.size() != sgLayoutInt.size()) {
    return failure();
  }

  SmallVector<Value> result(sgLayoutInt.size());
  Value remaining = linearId;

  /// Process dimensions in the order they appear in the order array.
  /// The first dimension in order is the fastest-changing.
  ///
  /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
  ///
  /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
  /// result=[?,?,?]
  ///
  /// i=0 (process columns, dimIdx=2, dimSize=4):
  ///   result[2] = 22 % 4 = 2 (column coordinate)
  ///   remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
  ///
  /// i=1 (process rows, dimIdx=1, dimSize=4):
  ///   result[1] = 5 % 4 = 1 (row coordinate)
  ///   remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
  ///
  /// i=2 (process layers, dimIdx=0, dimSize=2):
  ///   result[0] = 1 % 2 = 1 (layer coordinate)
  ///   (no remaining update - last iteration)
  ///
  /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    int64_t dimSize = sgLayoutInt[dimIdx];

    Value dimSizeVal =
        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);

    /// Extract the coordinate for this dimension using a modulo operation.
    /// This gives us "how far within this dimension" we are, e.g.,
    /// linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within this
    /// dimension).
    result[dimIdx] =
        builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);

    /// Update remaining for the next dimension by removing what we've already
    /// processed. Division tells us "how many complete groups of this
    /// dimension we've gone through", e.g., linearId=22, dimSize=4:
    /// 22 / 4 = 5 (we've completed 5 groups of 4). Skip this for the last
    /// iteration since there's no next dimension to process.
    if (i < order.size() - 1) {
      remaining =
          builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
    }
  }
  return result;
}

/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                     Value linearId, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }
  if (subShape.empty()) {
    if (auto derivedShape = computeShapeRatio(shape, layout))
      subShape = derivedShape.value();
    else
      return failure();
  }

  // Delinearize the linear id into per-dimension ids.
  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();
  SmallVector<Value> ids = *maybeIds;

  return genCoordinates(builder, loc, ids, layout, subShape, shape);
}

//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
  if (!parent || !dims)
    return emitError() << "expected parent layout and dims attribute";

  int64_t rank = parent.getRank();

  // Check that every element in dims is unique and smaller than rank.
  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t dim : dims.asArrayRef()) {
    if (dim < 0 || dim >= rank)
      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
    if (!seen.insert(dim).second)
      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
  }
  return success();
}

SliceAttr SliceAttr::flatten() const {
  xegpu::DistributeLayoutAttr parent = getParent();
  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
    parent = sliceAttr.getParent();
    slicedDims.push_back(sliceAttr.getDims());
  }

  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
  SmallVector<int64_t> indices =
      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));

  // Get the remaining (flattened) dims by applying the slices in slicedDims.
  SmallVector<int64_t> remainingDims(indices);
  for (auto dim : llvm::reverse(slicedDims))
    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
                                        dim.asArrayRef());

  // Get the flattened sliced dims by slicing out the remaining dims.
  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));

  return xegpu::SliceAttr::get(
      getContext(), layoutAttr,
      DenseI64ArrayAttr::get(getContext(), flattenedDims));
}
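
// Example of flattening nested slices (illustrative, not from the original
// source): take a rank-3 #xegpu.layout, slice away dim 0, then slice dim 1
// of the result. Walking inner-to-outer, the remaining dims of [0, 1, 2]
// are [1, 2] after the inner slice and [1] after the outer one (index 1 of
// [1, 2] is original dim 2), so the flattened attribute is a single slice
// of the root layout with dims = [0, 2].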

FailureOr<SmallVector<Value>>
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
  SliceAttr attr = flatten();
  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
  return parent.delinearizeId(builder, loc, linearId);
}

// Implements DistributeLayoutAttr::computeDistributedCoords to generate
// instructions for computing multi-dimensional offsets when distributed by
// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                    Value linearId, ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
  if (!isForWorkgroup())
    return failure();

  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }

  if (subShape.empty()) {
    if (auto derivedShape = computeShapeRatio(shape, layout))
      subShape = derivedShape.value();
    else
      return failure();
  }

  // Delinearize the linear id into per-dimension ids.
  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();

  // The effective subgroup ids used for computing offsets correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
  SmallVector<Value> sgIds =
      XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);

  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}

bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
  auto flattenedThis = flatten();
  // If other is a LayoutAttr, just compare directly with the parent of
  // flattenedThis.
  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
    return flattenedThis.getParent() == otherLayout;
  // If other is a SliceAttr, flatten it first before comparing.
  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
  // Both must have a common parent LayoutAttr.
  if (flattenedThis.getParent() != flattenedOther.getParent())
    return false;
  // flattenedOther's sliced dims must be a subset of flattenedThis's sliced
  // dims.
  llvm::SmallDenseSet<int64_t> thisDims(
      flattenedThis.getDims().asArrayRef().begin(),
      flattenedThis.getDims().asArrayRef().end());
  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
                      [&](int64_t dim) { return thisDims.contains(dim); });
}

//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//

LogicalResult
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "'end' : " << endOfRange.getInt()
                       << " must be greater than 'start' : "
                       << startOfRange.getInt();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//

mlir::Type TensorDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'.
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Parse optional attributes.
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }

  // Parse literal '>'.
  if (parser.parseGreater())
    return {};

  MLIRContext *ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}
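
// Textual forms accepted by this parser (illustrative, assuming the usual
// XeGPU attribute mnemonics; not from the original source):
//   !xegpu.tensor_desc<8x16xf32>
//   !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
//   !xegpu.tensor_desc<8x16xf32, #xegpu.layout<sg_layout = [4, 2]>>
// When the encoding is omitted, it defaults to a BlockTensorDescAttr with
// default values, which print() below elides again.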

void TensorDescType::print(AsmPrinter &printer) const {
  printer << "<";

  auto shape = getShape();
  for (int64_t dim : shape) {
    if (mlir::ShapedType::isDynamic(dim))
      printer << '?';
    else
      printer << dim;
    printer << 'x';
  }

  printer << getElementType();

  auto encoding = getEncoding();
  auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
    printer << ", " << encoding;

  if (auto layout = getLayout())
    printer << ", " << layout;

  printer << ">";
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int array_length,
                                   bool boundary_check,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                       boundary_check);
  return Base::get(context, shape, elementType, attr, layout);
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int chunk_size,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
  return Base::get(context, shape, elementType, attr, layout);
}

LogicalResult
TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                       llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
                       mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();

  if (rank == 0)
    return emitError() << "expected non-zero rank tensor";

  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (blockAttr) {
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank > 1 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is only supported for 1D block tensor";
  }

  // For gather and scatter ops, low-precision types are packed in 32-bit
  // units.
  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
  int chunkAlignmentFactor =
      bitWidth < uArch::generalPackedFormatBitSize
          ? uArch::generalPackedFormatBitSize / bitWidth
          : 1;
  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
  if (scatterAttr) {
    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
    if (rank == 1 && chunkSize != 1)
      return emitError() << "expected non-contiguous elements for 1D tensor";

    // If chunk size > 1, the second dimension of the tensor shape must be
    // equal to chunk size and it must be a multiple of the
    // chunkAlignmentFactor.
    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected last dim of tensor to match chunk size";
      if (shape.back() % chunkAlignmentFactor != 0)
        return emitError() << "expected last dim of tensor to be a multiple of "
                           << chunkAlignmentFactor;
    }
  }

  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
  if (layoutAttr) {
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

    auto laneData = layoutAttr.getLaneData();
    if (scatterAttr && laneData) {
      // Validate subgroup mapping rules for scattered tensors.
      // If chunkSize > 1, the last dimension of the tensor should
      // be distributed in units divisible by chunkAlignmentFactor.
      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
        return emitError()
               << "expected last dim of lane_data to be a multiple of: "
               << chunkAlignmentFactor;
    }

    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
      std::string shapeStr;
      llvm::raw_string_ostream stream(shapeStr);
      llvm::interleaveComma(shape, stream);
      return emitError() << "cannot distribute [" << shapeStr << "] using "
                         << layoutAttr;
    }
  }
  return success();
}
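
// Worked example for the chunk-size rules (illustrative, not from the
// original source): with f16 elements (bitWidth = 16) and a 32-bit packed
// format, chunkAlignmentFactor = 32 / 16 = 2. A scattered tensor_desc of
// shape 16x8 with chunk_size = 8 is valid (8 % 2 == 0), while a last dim of
// 3 would be rejected because it is not a multiple of 2.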

//===----------------------------------------------------------------------===//
// XeGPU_MemDescType
//===----------------------------------------------------------------------===//
mlir::Type MemDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;

  // Parse literal '<'.
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Parse optional attributes.
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }

  // Parse literal '>'.
  if (parser.parseGreater())
    return {};

  MLIRContext *ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}

void MemDescType::print(AsmPrinter &printer) const {
  printer << "<";

  printer.printDimensionList(getShape());
  printer << 'x';
  printer << getElementType();

  if (auto layout = getMemLayout())
    printer << ", " << layout;

  printer << ">";
}
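
// Textual forms accepted by this parser (illustrative, not from the
// original source):
//   !xegpu.mem_desc<32x64xf16>
//   !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32]>>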

//===----------------------------------------------------------------------===//
// XeGPU_MemLayoutAttr
//===----------------------------------------------------------------------===//

Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {

  auto context = parser.getContext();
  llvm::SMLoc loc = parser.getCurrentLocation();

  llvm::SmallDenseSet<StringRef> seenKeys;
  SmallVector<NamedAttribute> attributes;

  auto parseElt = [&]() -> ParseResult {
    StringRef nameId;
    if (failed(parser.parseKeyword(&nameId)))
      return parser.emitError(loc, "expected valid attribute name");

    if (!seenKeys.insert(nameId).second)
      return parser.emitError(loc, "duplicate key '")
             << nameId << "' in mem layout attribute";

    if (failed(parser.parseEqual()))
      return failure();

    Attribute attr;
    if (failed(parser.parseAttribute(attr)))
      return failure();
    attributes.emplace_back(nameId, attr);
    return success();
  };

  // Parse literal '<'.
  if (parser.parseLess())
    return {};

  if (failed(parser.parseCommaSeparatedList(parseElt)))
    return {};

  // Parse literal '>'.
  if (parser.parseGreater())
    return {};

  return parser.getChecked<MemLayoutAttr>(
      loc, context, DictionaryAttr::get(context, attributes));
}

void MemLayoutAttr::print(AsmPrinter &printer) const {
  printer << "<";
  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
  for (size_t i = 0; i < attrs.size(); i++) {
    printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
    if (i < attrs.size() - 1)
      printer << ", ";
  }
  printer << ">";
}
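
// Example printed form (illustrative, assuming the `stride` and `block`
// dictionary keys consumed by MemDescType below):
//   #xegpu.mem_layout<block = [8, 16], stride = [1, 32]>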

// A helper utility to perform a binary operation on OpFoldResults.
// If both a and b are attributes, it will simply return the result.
// Otherwise, the corresponding arith op will be generated, and a
// constant op will be created if one of them is an attribute.
template <typename ArithOp>
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
                      OpBuilder &builder) {
  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
  return ArithOp::create(builder, loc, aVal, bVal).getResult();
}

// A helper utility to perform a division operation on an OpFoldResult and an
// int64_t.
#define div(a, b)                                                              \
  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)

// A helper utility to perform a remainder operation on an OpFoldResult and an
// int64_t.
#define rem(a, b)                                                              \
  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)

// A helper utility to perform a multiply operation on an OpFoldResult and an
// int64_t.
#define mul(a, b)                                                              \
  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)

// A helper utility to perform an addition operation on two OpFoldResults.
#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
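
// For example, `div(ofs, 8)` expands to
// `genBinOp<arith::DivSIOp>(ofs, builder.getIndexAttr(8), loc, builder)`,
// which materializes index constants as needed and emits the corresponding
// arith op. The macros assume `builder` and `loc` are in scope at the
// expansion site.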

// Block the given offsets according to the block shape.
// Say the original offset is [y, x] and the block shape is [By, Bx];
// then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
                                            ArrayRef<OpFoldResult> offsets,
                                            ArrayRef<int64_t> blockShape) {

  assert(offsets.size() == blockShape.size() &&
         "offsets and blockShape must have the same size");
  SmallVector<OpFoldResult> blockedOffsets;
  SmallVector<OpFoldResult> divs, rems;

  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
    divs.push_back(div(offset, block));
    rems.push_back(rem(offset, block));
  }
  blockedOffsets.append(divs.begin(), divs.end());
  blockedOffsets.append(rems.begin(), rems.end());

  return blockedOffsets;
}
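
// Worked example (illustrative, not from the original source):
// offsets = [10, 37] with blockShape = [8, 16] yields
// divs = [10/8, 37/16] = [1, 2] and rems = [10%8, 37%16] = [2, 5],
// so the blocked offsets are [1, 2, 2, 5].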

// Get strides as a vector of integers for MemDesc.
SmallVector<int64_t> MemDescType::getStrideShape() {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());

  ArrayAttr strideAttr = getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  SmallVector<int64_t> innerBlkShape = getBlockShape();

  // Get the permutation from the fastest-changing dim to the slowest:
  // perm[i] = the dim with the i-th smallest stride.
  SmallVector<int, 4> perm =
      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });

  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");

  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
  innerBlkStride[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    innerBlkStride[perm[i]] =
        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];

  // Compute the original matrix shape using the stride info,
  // and compute the number of blocks in each dimension.
  // The shape of the highest dim can't be derived from the stride info,
  // but that doesn't impact the stride computation for the blocked layout.
  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
  }

  int64_t innerBlkSize = 1;
  for (auto s : innerBlkShape)
    innerBlkSize *= s;

  SmallVector<int64_t> outerBlkStride(matrixShape.size());
  outerBlkStride[perm[0]] = innerBlkSize;
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    outerBlkStride[perm[i + 1]] =
        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
  }

  // Combine the inner and outer strides.
  SmallVector<int64_t> blockedStrides;
  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());

  return blockedStrides;
}
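
// Worked example (illustrative, not from the original source): a row-major
// 32x64 matrix with stride = [64, 1] and block shape [8, 16]. Then
// perm = [1, 0] (dim 1 has the smallest stride), innerBlkStride = [16, 1],
// the derived width is 64/1 = 64 giving 64/16 = 4 blocks along dim 1,
// innerBlkSize = 8*16 = 128, and outerBlkStride = [512, 128]. The combined
// result is [512, 128, 16, 1].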

// Calculate the linear offset using the blocked offsets and strides.
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
                                    ArrayRef<OpFoldResult> offsets) {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
  SmallVector<int64_t> blockShape = getBlockShape();
  SmallVector<int64_t> strides = getStrideShape();
  SmallVector<OpFoldResult> blockedOffsets;

  // A blockShape equal to matrixShape means no blocking.
  if (llvm::equal(blockShape, matrixShape)) {
    // Remove the outer dims from the strides.
    strides.erase(strides.begin(), strides.begin() + matrixShape.size());
  } else {
    assert(offsets.size() == blockShape.size() &&
           "offsets and blockShape must have the same size");
    // Say the original offset is [y, x] and the block shape is [By, Bx];
    // then the blocked offset is [y/By, x/Bx, y%By, x%Bx].

    SmallVector<OpFoldResult> divs, rems;

    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
      divs.push_back(div(offset, block));
      rems.push_back(rem(offset, block));
    }
    blockedOffsets.append(divs.begin(), divs.end());
    blockedOffsets.append(rems.begin(), rems.end());
    offsets = blockedOffsets;
  }

  // Start with the initial value as the matrix descriptor's base offset.
  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
  for (size_t i = 0; i < offsets.size(); ++i) {
    OpFoldResult mulResult = mul(offsets[i], strides[i]);
    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
  }

  return linearOffset;
}
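
// Worked example (illustrative), continuing the 32x64 case above with block
// shape [8, 16]: offsets [10, 37] become blocked offsets [1, 2, 2, 5] with
// strides [512, 128, 16, 1], so the linear offset is
// 1*512 + 2*128 + 2*16 + 5*1 = 805 elements from the descriptor base.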

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>