MLIR 23.0.0git
XeGPULayoutImpl.cpp
Go to the documentation of this file.
1//===---- XeGPULayoutImpl.cpp - MLIR Utilities for XeGPUOps
2//------------------===//
3//
4// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements layout utility functions for XeGPU dialect
11// transformation.
12//
13//===----------------------------------------------------------------------===//
14
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/ValueRange.h"
27#include "llvm/Support/FormatVariadic.h"
28#include <cstdint>
29#include <numeric>
30
31using namespace mlir;
32
34 op->walk([&](Operation *nestOp) {
35 for (OpOperand &opr : nestOp->getOpOperands()) {
36 auto layout = getDistributeLayoutAttr(opr.get());
37 setDistributeLayoutAttr(opr, layout);
38 }
39
40 for (OpResult result : nestOp->getOpResults()) {
41 auto layout = getDistributeLayoutAttr(result);
43 }
44 });
45}
46
50 out.reserve(attrs.size());
51
52 for (auto attr : attrs) {
53 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
54 auto newLayout = dist.dropSgLayoutAndData();
55 if (newLayout)
56 out.emplace_back(attr.getName(), newLayout);
57 } else {
58 out.push_back(attr);
59 }
60 }
61
62 return out;
63}
64
68 out.reserve(attrs.size());
69
70 for (auto attr : attrs) {
71 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
72 auto newLayout = dist.dropInstData();
73 if (newLayout)
74 out.emplace_back(attr.getName(), newLayout);
75 } else {
76 out.push_back(attr);
77 }
78 }
79
80 return out;
81}
82
83// Attach layout attributes to all vector-type operands of operations within
84// the given operation's region. Reports an error if any vector operand lacks
85// a layout attribute.
87 auto result = rootOp->walk([&](Operation *op) {
88 for (OpOperand &operand : op->getOpOperands()) {
89 // Layouts are needed for vector type only.
90 if (!isa<VectorType>(operand.get().getType()))
91 continue;
92 // Skip block arguments since they don't have defining ops to attach
93 // layout attributes to.
94 if (isa<BlockArgument>(operand.get()))
95 continue;
96 auto layout = xegpu::getDistributeLayoutAttr(operand.get());
97 if (!layout) {
98 op->emitWarning("Could not find layout attribute for operand ")
99 << operand.getOperandNumber() << " of operation " << op->getName();
100 continue;
101 }
102 xegpu::setTemporaryLayout(operand, layout);
103 }
104 return WalkResult::advance();
105 });
106 return !result.wasInterrupted();
107}
108
109template <typename T, typename>
110void xegpu::removeLayoutAttr(const T &operandOrResult) {
111 Operation *owner = operandOrResult.getOwner();
112 std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
113 if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
114 owner->removeAttr(name);
115}
116
117// Explicit instantiation for OpResult
118template void
120
121// Explicit instantiation for OpOperand
122template void
124
126 op->walk([&](Operation *nestOp) {
127 // Remove all attributes of DistributeLayoutAttr type
128 SmallVector<StringAttr> attrsToRemove;
129 for (auto namedAttr : nestOp->getAttrs()) {
130 if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
131 attrsToRemove.push_back(namedAttr.getName());
132 }
133 for (auto attrName : attrsToRemove)
134 nestOp->removeAttr(attrName);
135 });
136}
137
138/// Infers the source layout attribute for a broadcast operation given the
139/// result layout attribute, result shape, source shape.
140xegpu::DistributeLayoutAttr
141xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
142 ArrayRef<int64_t> resShape,
143 ArrayRef<int64_t> srcShape) {
144
145 SmallVector<int64_t> bcastDims;
146 size_t dimDiff = resShape.size() - srcShape.size();
147 auto bcastSourceLayout = resLayout;
148 for (size_t i = dimDiff; i < resShape.size(); i++) {
149 if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
150 bcastDims.push_back(i);
151 }
152
153 // the sg_layout and lane_layout for unit dimensions are preserved so it can
154 // be propagate to producer op so potentially used by the multi-reduction op.
155 if (!bcastDims.empty())
156 bcastSourceLayout = bcastSourceLayout.setUnitDimData(bcastDims);
157
158 if (dimDiff > 0) {
159 SmallVector<int64_t> sliceDims;
160 for (size_t i = 0; i < dimDiff; i++)
161 sliceDims.push_back(i);
162 bcastSourceLayout = xegpu::SliceAttr::get(
163 resLayout.getContext(), bcastSourceLayout,
164 DenseI64ArrayAttr::get(resLayout.getContext(), sliceDims));
165 }
166 return bcastSourceLayout;
167}
168
169/// Infers the source layout attribute for a reduction operation given the
170/// result layout attribute and reduced dims.
171xegpu::DistributeLayoutAttr
172xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
173 SmallVector<int64_t> reduceDims) {
174
175 assert(isa<xegpu::SliceAttr>(resLayout) &&
176 "reduction result layout must be slice layout");
177
178 xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
179
180 assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
181 "reduction dims must match with slice dims");
182
183 return sliceLayout.getParent();
184}
185
186xegpu::DistributeLayoutAttr
187xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
188 return xegpu::inferMultiReductionSourceLayout(resLayout, {0});
189}
190
191/// Infers the source layout attribute for a transpose operation given the
192/// result layout attribute and permutation.
193xegpu::DistributeLayoutAttr
194xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
195 ArrayRef<int64_t> permutation) {
196 return resLayout.transposeDims(permutation);
197}
198
199/// Infers the source layout attribute for a bitcast operation given the
200/// result layout attribute, result element type bitwidth, and source element
201/// type bitwidth.
202xegpu::DistributeLayoutAttr
203xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
204 int resElemTyBitWidth, int srcElemTyBitWidth) {
205
206 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
207 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
208 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
209 size_t sgDataSize = sgData.size();
210 size_t instDataSize = instData.size();
211 size_t laneDataSize = laneData.size();
212 int64_t sgDataValue = -1;
213 int64_t instDataValue = -1;
214 int64_t laneDataValue = -1;
215 int64_t dim = resLayout.getRank() - 1;
216
217 if (srcElemTyBitWidth <= resElemTyBitWidth) {
218 int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
219 if (sgDataSize)
220 sgDataValue = sgData.back() * bitWidthRatio;
221 if (instDataSize)
222 instDataValue = instData.back() * bitWidthRatio;
223 if (laneDataSize)
224 laneDataValue = laneData.back() * bitWidthRatio;
225 } else {
226 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
227 if (sgDataSize) {
228 assert((sgData.back() % bitWidthRatio) == 0 &&
229 "sgData not divisible by bitWidthRatio");
230 sgDataValue = sgData.back() / bitWidthRatio;
231 }
232 if (instDataSize) {
233 assert((instData.back() % bitWidthRatio) == 0 &&
234 "instData not divisible by bitWidthRatio");
235 instDataValue = instData.back() / bitWidthRatio;
236 }
237 if (laneDataSize) {
238 assert((laneData.back() % bitWidthRatio) == 0 &&
239 "laneData not divisible by bitWidthRatio");
240 laneDataValue = laneData.back() / bitWidthRatio;
241 }
242 }
243
244 xegpu::DistributeLayoutAttr finalSrcLayout;
245 finalSrcLayout =
246 resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
247
248 return finalSrcLayout;
249}
250
251/// Infers the source layout attribute for an insert strided slice operation
252/// given the result layout attribute, result shape, and source shape. Removes
253/// leading dimensions from the result layout to match the source shape size.
254xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
255 xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
256 ArrayRef<int64_t> srcShape) {
257
258 int srcShapeSize = srcShape.size();
259 int resShapeSize = resShape.size();
260 int dimDiff = resShapeSize - srcShapeSize;
261
262 if (dimDiff > 0) {
263 // assert that the leading dimensions being sliced off are not distributed
264 // (i.e. sg_layout and lane_layout for those dimensions are all 1)
265 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
266 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
267 for (int i = 0; i < dimDiff; i++) {
268 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
269 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
270 "Leading dimensions being sliced off must not be distributed");
271 }
272 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
273 }
274 return resLayout;
275}
276
277/// Infers the source layout attribute for a shape cast operation given the
278/// result layout attribute, result shape, and source shape.
279xegpu::DistributeLayoutAttr
280xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
281 ArrayRef<int64_t> resShape,
282 ArrayRef<int64_t> srcShape) {
283
284 // There are three use cases:
285 // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
286 // tensor before broadcast
287 // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
288 // for multi-stage reduction
289 // 3. combines all dims to a single dim and put in the innermost dim in 2d as
290 // [1, combinedData] or [combinedData]. Say, [2, 4, 8] -> [1, 64] or [64]
291 // Use cases are only supported after workgroup distribution,
292 // like cross-sg reduction saves multidimension data to
293 // 1D slm buffer, shapecast inserted by cse/canonicalization passes.
294
295 // Use case 1: Shapes only differ by expanding unit dimensions, for broadcast
296 SmallVector<int64_t> expandedUnitDims;
297
298 if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
299 // create a slice layout for the source by removing the expanded unit dims
300 auto sliceDimsAttr = DenseI64ArrayAttr::get(
301 resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
302 auto srcLayout =
303 xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
304 return srcLayout;
305 }
306
307 // Use case 2: Dim split from source to result, for multi-stage reduction
308 SmallVector<SmallVector<int64_t>> splitDimGroups;
309 if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups)) {
310 auto srcLayout = resLayout;
311 for (const auto &dimGroup : splitDimGroups)
312 srcLayout = srcLayout.collapseDims(dimGroup);
313
314 return srcLayout;
315 }
316
317 // Use case 3: Collaspse to innermost dim, for cross-sg reduction to SLM
318 auto matchCollapseToInnermostDim = [&](ArrayRef<int64_t> src,
319 ArrayRef<int64_t> dst) -> bool {
320 // only one non-unit dim in dst which is the innermost dim
321 if ((dst.size() != 2) && (dst.size() != 1))
322 return false;
323 int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
324 std::multiplies<int64_t>());
325 if (dst.size() == 1)
326 return (dst[0] == srcSize);
327 return (dst[0] == 1) && (dst[1] == srcSize);
328 };
329
330 if (matchCollapseToInnermostDim(srcShape, resShape)) {
331 int srcShapeSize = srcShape.size();
332 int resShapeSize = resShape.size();
333 auto context = resLayout.getContext();
334 auto resInstData = resLayout.getEffectiveInstDataAsInt();
335 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
336 auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
337
338 // Extract layout info from result's innermost dimension and apply to
339 // source's innermost dimension while setting all other dimensions to 1.
340 // The inferred layout is restricted by srcShape to ensure it fits within
341 // the source dimensions.
342 // Examples 1:
343 // srcShape=[8, 16, 32], resShape=[1, 4096]
344 // resInstData=[1, 16]
345 // -> inferredInstData=[1, 1, min(16, 32)]=[1, 1, 16]
346 // Examples 2:
347 // srcShape=[4, 8, 64], resShape=[2048]
348 // resLaneLayout=[16], resLaneData=[2]
349 // -> inferredLaneLayout=[1, 1, 16]
350 // -> inferredLaneData=[1, 1, min(2, 64/16)]=[1, 1, 2]
351
352 if (resInstData.size() != 0) {
353 // assert resInstData must be 1 for all but the innermost dim
354 for (int i = 0; i < resShapeSize - 1; i++) {
355 assert(resInstData[i] == 1 &&
356 "only innermost dim can have non-unit instData");
357 }
358 SmallVector<int> inferredInstData(srcShapeSize, 1);
359 inferredInstData[srcShapeSize - 1] =
360 std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
361 return xegpu::LayoutAttr::get(context, inferredInstData);
362 }
363
364 if (resLaneLayout.size() != 0) {
365 for (int i = 0; i < resShapeSize - 1; i++) {
366 assert(resLaneData[i] == 1 &&
367 "only innermost dim can have non-unit instData");
368 }
369 assert(srcShape.back() % resLaneLayout.back() == 0 &&
370 "source innermost dim must be >= result lane layout");
371 SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
372 SmallVector<int> inferredLaneData(srcShapeSize, 1);
373 inferredLaneLayout.back() = resLaneLayout.back();
374 inferredLaneData.back() = std::min(
375 resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
376 return xegpu::LayoutAttr::get(context, inferredLaneLayout,
377 inferredLaneData);
378 }
379 }
380 llvm_unreachable("running into unsupported shape cast scenarios");
381 return nullptr;
382}
383
384/// Infers the layout attribute for mask and offset operand for Chunked load
385/// and store, given the anchor layout attribute for the value being load/store.
386xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
387 xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
388 auto rank = payloadLayout.getRank();
389 if (chunkSize > 1)
390 return payloadLayout.dropDims(
391 llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
392 return payloadLayout;
393}
394
395/// Sets up layout for reduction operations by creating a SliceAttr for the
396/// result.
397///
398/// Algorithm Overview:
399/// This function attempts to construct a source layout that, when sliced along
400/// reduction dimensions, produces a result layout compatible with the
401/// consumer layout.
402///
403/// For subgroup layouts, it first tries to align the source layout's subgroup
404/// layout and data with the consumer's layout on non-reduction dimensions.
405/// Then, it distributes remaining subgroups across reduction dimensions. This
406/// avoids subgroup data redistribution overhead between the reduced result and
407/// its consumer. When the consumer layout is a slice layout, it attempts to
408/// reuse the slice layout's parent layout for the source to further minimize
409/// potential data redistribution.
410///
/// InstData requires {1, ..., min(maxReduceVectorSize, srcShape), subgroupSize}
412/// Lane Layout requires {1, ..., 1, subgroupSize}
413/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
414///
415/// Examples:
416/// 1. Subgroup layout - Row reduction on 2D tensor:
417/// srcShape=[32, 128], reductionDims=[1], resShape=[32], subgroupSize=16,
418/// NumSg=32
419/// * Consumer Layout:
420/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>, dims =
421/// [1]>}
/// * Result Layout:
423/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8],sg_data=[8, 16]>, dims =
424/// [1]>}
425/// Note that the sg_layout is reused but sg_data needs to be adjusted to
426/// evenly distribute the source tensor tile among the reduction dim.
427///
428/// 2. Subgroup layout - Same example above but consumer doesn't have a
429/// reusable slice layout.
430/// * Consumer Layout:
431/// #xegpu.layout<sgLayout=[32], sgData=[1]>
432/// * Result Layout:
433/// #xegpu.slice<#xegpu.layout<sgLayout=[32,1], sgData=[1, 64]>, dims =
434/// [1]>}
435/// * Consumer Layout:
436/// #xegpu.slice<#xegpu.layout<sgLayout=[8, 2, 4], sgData=[4, 64, 32]>,
437/// dims = [1, 2]>}
438/// * Result Layout:
439/// #xegpu.slice<#xegpu.layout<sgLayout=[8,4], sgData=[4, 32]>, dims =
440/// [1]>}
441/// Note that the consumer's layout can't be directly reused as is.
442/// So the algorithm distributes all subgroups on non reduction dimensions
443/// first and then distribute remaining subgroups on the reduction
444/// dimension.
445///
/// 3. InstData layout - Column reduction:
/// srcShape=[32, 64], reductionDims=[0], subgroupSize=16
/// Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
/// innermost)
///
/// 4. Lane layout - Multi-dimensional reduction:
/// srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
/// Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
/// (subgroupSize on innermost dim, max vector size on reduction dim)
455
457 xegpu::LayoutKind layoutKind, VectorType srcVecTy,
458 DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
459 int numSg, const xegpu::uArch::uArch *uArch) {
460
461 auto srcShape = srcVecTy.getShape();
462 int srcRank = srcShape.size();
463 auto context = srcVecTy.getContext();
464
465 // Helper lambda to convert int64 vectors to int32 DenseArrayAttr
466 auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
467 SmallVector<int32_t> vec32(vec.begin(), vec.end());
468 return DenseI32ArrayAttr::get(context, vec32);
469 };
470
471 const int subgroupSize = uArch->getSubgroupSize();
472 int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
473 xegpu::DistributeLayoutAttr srcLayout;
474 if (layoutKind == xegpu::LayoutKind::Subgroup) {
475 xegpu::SliceAttr consumerSliceLayout =
476 dyn_cast_if_present<xegpu::SliceAttr>(consumerLayout);
477 if (consumerSliceLayout &&
478 consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
479 srcLayout = consumerSliceLayout.getParent();
480 SmallVector<int64_t> sgLayoutFromConsumer =
481 srcLayout.getEffectiveSgLayoutAsInt();
482 auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
483 if (srcSgData)
484 for (int dim = 0; dim < srcRank; dim++) {
485 if (llvm::is_contained(reductionDims, dim))
486 srcLayout =
487 srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
488 }
489 } else {
490 SmallVector<int64_t> consumerSgLayout =
491 consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
493 SmallVector<int64_t> consumerSgData =
494 consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
496 SmallVector<int64_t> consumerOrder =
497 consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
499 DenseI32ArrayAttr orderAttr =
500 consumerLayout ? consumerLayout.getOrder() : nullptr;
501 SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
502 int remainingSgCount =
503 consumerLayout ? consumerLayout.getNumSubgroups() : numSg;
504 int consumerIdx = 0;
505
506 // First pass: Match consumer's layout on non-reduction dimensions
507 for (int i = 0; i < srcRank; i++) {
508 if (!llvm::is_contained(reductionDims, i) &&
509 consumerIdx < static_cast<int>(consumerSgLayout.size())) {
510 sgLayout[i] = consumerSgLayout[consumerIdx];
511 sgData[i] = consumerSgData[consumerIdx];
512 remainingSgCount /= sgLayout[i];
513 order[i] = consumerOrder[consumerIdx];
514 consumerIdx++;
515 }
516 }
517
518 // Second pass: Distribute remaining subgroups across reduction dimensions
519 // the reduction to scalar case is handled only by this loop
520 int64_t remainOrder = consumerSgLayout.size();
521 for (int i = 0; i < srcRank; i++) {
522 if (llvm::is_contained(reductionDims, i)) {
523 sgLayout[i] =
524 std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
525 assert((srcShape[i] % sgLayout[i] == 0) &&
526 "source shape not divisible by sg_layout");
527 sgData[i] = srcShape[i] / sgLayout[i];
528 remainingSgCount /= sgLayout[i];
529 order[i] = remainOrder++;
530 }
531 }
532
533 assert(remainingSgCount == 1 && "not all subgroups distributed");
534 srcLayout = xegpu::LayoutAttr::get(
535 context, toInt32Attr(sgLayout), toInt32Attr(sgData),
536 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
537 /*lane_data =*/nullptr, /*order =*/
538 (!orderAttr || orderAttr.empty()) ? nullptr : toInt32Attr(order));
539 }
540 } else if (layoutKind == xegpu::LayoutKind::InstData) {
541
542 SmallVector<int64_t> instData(srcRank, 1);
543 if (srcRank >= 2)
544 instData[srcRank - 2] =
545 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
546 instData[srcRank - 1] =
547 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
548 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
549 } else if (layoutKind == xegpu::LayoutKind::Lane) {
550
551 SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
552 laneLayout[srcRank - 1] =
553 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
554 if (srcRank >= 2)
555 laneData[srcRank - 2] =
556 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
557 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
558 toInt32Attr(laneData));
559 }
560
561 return xegpu::SliceAttr::get(context, srcLayout,
562 DenseI64ArrayAttr::get(context, reductionDims));
563}
564
565/// Sets up layout for Reduction operations by creating a SliceAttr for the
566/// result.
567xegpu::SliceAttr
569 VectorType srcVecTy,
570 const xegpu::uArch::uArch *uArch) {
571
572 auto srcShape = srcVecTy.getShape();
573 auto context = srcVecTy.getContext();
574 auto subgroupSize = uArch->getSubgroupSize();
575 xegpu::LayoutAttr srcLayout;
576
577 if (layoutKind == xegpu::LayoutKind::Subgroup) {
578 assert(true && "subgroup layout assignment not supported for reduction (op "
579 "is not expected at this level).");
580 } else if (layoutKind == xegpu::LayoutKind::InstData) {
581 assert(true && "instData layout assignment not supported for reduction (op "
582 "is not expected at this level).");
583 } else if (layoutKind == xegpu::LayoutKind::Lane) {
584 SmallVector<int32_t> laneLayout(1), laneData(1);
585 laneLayout[0] = std::min(subgroupSize, static_cast<int32_t>(srcShape[0]));
586 laneData[0] = 1;
587 srcLayout = xegpu::LayoutAttr::get(
588 context, DenseI32ArrayAttr::get(context, laneLayout),
589 DenseI32ArrayAttr::get(context, laneData));
590 }
591
592 auto result = xegpu::SliceAttr::get(context, srcLayout,
593 DenseI64ArrayAttr::get(context, 0));
594 return result;
595}
596
597/// Sets up the result layout for a bitcast operation.
598/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
599/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
600/// result layout can be correctly divided back to the source layout during
601/// inference.
602///
603/// Examples:
604/// 1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
605/// Consumer layout: instData=[1, 16], subgroupSize=16
606/// Source shape: [8, 32]
607/// Result layout: instData=[1, 32] (16 * 2)
608/// The innermost dimension is multiplied by 2 to maintain consistency.
609///
610/// 2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
611/// Consumer instData=[1, 16], subgroupSize=16
612/// Source shape: [4, 128]
613/// adjust the instData from [1, 16] to [1, 16 * 4 = 64]
614///
615/// 3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
616/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
617/// No adjustment needed - returns consumer layout directly.
618///
xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
    xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
    DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {

  int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
  int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();

  ArrayRef<int64_t> srcShape = srcVecTy.getShape();
  SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
  SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
  SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
  // NOTE(review): this checks rank agreement between the consumer layout and
  // the source vector; the message talks about laneData availability.
  assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
         "laneData must be available for all dimensions");
  // Only the innermost dimension's layout entries are adjusted.
  size_t dim = srcShape.size() - 1;
  // -1 means "leave this field of the layout unchanged" for setDimData.
  int64_t sgDataValue = -1;
  int64_t instDataValue = -1;
  int64_t laneDataValue = -1;
  const int subgroupSize = uArch->getSubgroupSize();

  if (srcElemTyBitWidth > resElemTyBitWidth) {
    // When casting to a smaller bitwidth, multiply the result layout
    // accordingly to ensure it can be divided by the ratio back to the
    // source layout.
    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
    int innermostDimLaneLayout = subgroupSize;
    if (layoutKind == xegpu::LayoutKind::Subgroup) {
      // Subgroup data is carried over unchanged for the innermost dim.
      sgDataValue = sgData[dim];
    } else if (layoutKind == xegpu::LayoutKind::InstData) {
      instDataValue = instData[dim];
      // Adjust instDataValue so it still fits within an instruction after
      // dividing by bitWidthRatio. Doubling continues until the value is a
      // multiple of laneLayout * ratio (or exceeds the source extent);
      // presumably all sizes involved are powers of two — see assert below.
      while ((instDataValue <= srcShape[dim]) &&
             (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
        instDataValue *= 2;
      assert((srcShape[dim] % instDataValue) == 0 &&
             "srcShape, instData, and lanelayout for innermost must be 2^n !");
    } else if (layoutKind == xegpu::LayoutKind::Lane) {
      laneDataValue = laneData[dim];
      // Grow laneData until it divides evenly by the bitwidth ratio so the
      // inferred source laneData stays integral.
      while ((laneDataValue <= srcShape[dim]) &&
             (laneDataValue % bitWidthRatio != 0))
        laneDataValue *= 2;
    }
    // Now set only instData and laneData, preserving sgData
    xegpu::DistributeLayoutAttr resLayout;
    resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
                                          laneDataValue);
    return resLayout;
  }
  // Casting to an equal or wider element type needs no adjustment.
  return consumerLayout;
}
669
670/// Sets up the result layout for an insert strided slice operation.
671/// Creates a result layout based on the specified layout kind (InstData or
672/// Lane).
673xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
674 xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
675 VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
676 const xegpu::uArch::uArch *uArch) {
677
678 xegpu::DistributeLayoutAttr requiredResLayout;
679 SmallVector<int64_t> consumerInstData =
680 consumerLayout.getEffectiveInstDataAsInt();
681 SmallVector<int64_t> consumerLaneData =
682 consumerLayout.getEffectiveLaneDataAsInt();
683 SmallVector<int64_t> consumerLaneLayout =
684 consumerLayout.getEffectiveLaneLayoutAsInt();
685 ArrayRef<int64_t> srcShape = srcVectorTy.getShape();
686 int64_t instDataValue = -1;
687 int64_t laneDataValue = -1;
688
689 requiredResLayout = consumerLayout;
690 int srcRank = srcShape.size();
691
692 if (layoutKind == xegpu::LayoutKind::Subgroup) {
693 assert(true &&
694 "subgroup layout assignment not supported for insertStridedSlice.");
695 } else if (layoutKind == xegpu::LayoutKind::InstData) {
696 for (int dim = 0; dim < srcRank; dim++) {
697 instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
698 requiredResLayout =
699 requiredResLayout.setDimData(dim, -1, instDataValue, -1);
700 }
701 } else if (layoutKind == xegpu::LayoutKind::Lane) {
702 for (int dim = 0; dim < srcRank; dim++) {
703 assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
704 "srcShape must be divisible by laneLayout for all dimensions");
705 laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
706 consumerLaneData[dim]);
707 requiredResLayout =
708 requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
709 }
710 }
711 return requiredResLayout;
712}
713
714/// Sets up the anchor layout for load gather and load matrix operation.
715/// load matrix lowers to load gather and 1d block load. All of them share the
716/// same layout setup logic.
717/// For Subgroup layout, uses the consumer layout directly.
718/// non-chunked loads:
719/// InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
720/// LaneLayout = {1, ..., subgroupSize}
721/// lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
722/// chunked loads:
723/// InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
724/// LaneLayout = {subgroupSize, 1}
725/// lane_data={1,min(consumer, maxLaneLoadSize)}
726static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
727 xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
728 xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
729 int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
730
731 if (layoutKind == xegpu::LayoutKind::Subgroup)
732 return consumerLayout;
733
734 SmallVector<int64_t> consumerInstData =
735 consumerLayout.getEffectiveInstDataAsInt();
736 SmallVector<int64_t> consumerLaneData =
737 consumerLayout.getEffectiveLaneDataAsInt();
738
739 SmallVector<int> instData(resShape.size(), 1);
740 SmallVector<int> laneLayout(resShape.size(), 1);
741 SmallVector<int> laneData(resShape.size(), 1);
742
743 if (!isChunkedLoad) {
744 if (layoutKind == xegpu::LayoutKind::InstData) {
745 instData.back() = std::min(static_cast<int>(consumerInstData.back()),
746 maxChunkSize * subgroupSize);
747 return xegpu::LayoutAttr::get(context, instData);
748 } else if (layoutKind == xegpu::LayoutKind::Lane) {
749 laneData.back() =
750 std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
751 laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
752 resShape.back() / laneData.back());
753 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
754 }
755 } else {
756 assert(resShape.size() == 2 && "Chunked Store must access 2D tensor tile.");
757 if (layoutKind == xegpu::LayoutKind::InstData) {
758 instData[0] = subgroupSize;
759 instData[1] =
760 std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
761 return xegpu::LayoutAttr::get(context, instData);
762 } else if (layoutKind == xegpu::LayoutKind::Lane) {
763 laneLayout[0] = subgroupSize;
764 laneData[1] =
765 std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
766 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
767 }
768 }
769 return nullptr;
770}
771
772/// Sets up the anchor layout for a load gather operation.
773xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
774 xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
775 xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
776
777 const int subgroupSize = uArch->getSubgroupSize();
778 ArrayRef<int64_t> resShape = resVecTy.getShape();
779 auto context = resVecTy.getContext();
780 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
781
782 const auto *uArchInstruction =
783 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
785 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
786
787 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
788 (chunkSize > 1), maxChunkSize, resShape,
789 subgroupSize);
790}
791
792/// Sets up the anchor layout for load matrix operation.
793/// TODO: enhance load matrix to indicate lowering to chunked load or not.
794xegpu::DistributeLayoutAttr
796 VectorType resVecTy,
797 xegpu::DistributeLayoutAttr consumerLayout,
798 const xegpu::uArch::uArch *uArch) {
799
800 const int subgroupSize = uArch->getSubgroupSize();
801 ArrayRef<int64_t> resShape = resVecTy.getShape();
802 auto context = resVecTy.getContext();
803 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
804
805 const auto *uArchInstruction =
806 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
808 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
809 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
810 false, maxChunkSize, resShape,
811 subgroupSize);
812}
813
814/// Sets up the anchor layout for store scatter and store matrix operation.
815/// store matrix lowers to store scatter and 1d block store. All of them share
816/// the same layout setup logic. For Subgroup layout, not support yet.
817/// non-chunked stores:
818/// InstData = {1, ..., subgroupSize}
819/// LaneLayout = {1, ..., subgroupSize}
820/// lane_data = {1, ..., 1}
821/// chunked stores:
822/// InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
823/// LaneLayout = {subgroupSize, 1}
824/// lane_data={1,min(srcVec, maxLaneStoreSize)}
825static xegpu::DistributeLayoutAttr
827 mlir::MLIRContext *context, bool isChunkedStore,
828 int maxChunkSize, ArrayRef<int64_t> srcShape,
829 int subgroupSize) {
830
831 int srcShapeSize = srcShape.size();
832 SmallVector<int> instData(srcShapeSize, 1);
833 SmallVector<int> laneLayout(srcShapeSize, 1);
834 SmallVector<int> laneData(srcShapeSize, 1);
835
836 if (layoutKind == xegpu::LayoutKind::Subgroup) {
837 assert(true &&
838 "subgroup layout assignment not supported for storeScatter.");
839 return nullptr;
840 }
841
842 if (!isChunkedStore) {
843 if (layoutKind == xegpu::LayoutKind::InstData) {
844 instData[srcShapeSize - 1] =
845 std::min(subgroupSize, static_cast<int>(srcShape.back()));
846 return xegpu::LayoutAttr::get(context, instData);
847 } else if (layoutKind == xegpu::LayoutKind::Lane) {
848 laneLayout[srcShapeSize - 1] =
849 std::min(subgroupSize, static_cast<int>(srcShape.back()));
850 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
851 }
852 } else {
853 assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
854 if (layoutKind == xegpu::LayoutKind::InstData) {
855 instData[0] = subgroupSize;
856 instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
857 return xegpu::LayoutAttr::get(context, instData);
858 } else if (layoutKind == xegpu::LayoutKind::Lane) {
859 laneLayout[0] = subgroupSize;
860 laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
861 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
862 }
863 }
864 return nullptr;
865}
866
867/// Sets up the anchor layout for a store scatter operation.
868xegpu::DistributeLayoutAttr
870 VectorType srcVecTy, int chunkSize,
871 const uArch::uArch *uArch) {
872
873 const int subgroupSize = uArch->getSubgroupSize();
874 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
875 auto context = srcVecTy.getContext();
876 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
877
878 const auto *uArchInstruction =
879 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
881 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
882 return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
883 maxChunkSize, srcShape, subgroupSize);
884}
885
886/// Sets up the anchor layout for a store matrix operation.
887xegpu::DistributeLayoutAttr
889 VectorType srcVecTy,
890 const xegpu::uArch::uArch *uArch) {
891
892 const int subgroupSize = uArch->getSubgroupSize();
893 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
894 auto context = srcVecTy.getContext();
895 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
896
897 const auto *uArchInstruction =
898 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
900 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
901
902 return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
903 srcShape, subgroupSize);
904}
905
906// This function returns the default lane layout for a given vector type.
907// - `packingSize` means multiple consecutive elements can be accessed
908// together as a single unit.
909// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
910// 1x2xf16 w/o vnni).
911template <typename RankedTy>
912static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(
913 RankedTy ty, const xegpu::uArch::uArch *uArch,
914 std::optional<unsigned> packingSize = std::nullopt, bool vnni = false) {
915 // Expecting a 1D or 2D vector.
916 assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
917 "Expected 1D non-vnni or 2D vector.");
918 // Expecting int or float element type.
919 assert(ty.getElementType().isIntOrFloat() &&
920 "Expected int or float element type.");
921
922 auto context = ty.getContext();
923 auto rank = ty.getRank();
924 SmallVector<int> laneLayout(rank, 1);
925 SmallVector<int> laneData(rank, 1);
926 if (packingSize.has_value()) {
927 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
928 int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
929 laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
930 }
931 laneLayout.back() = uArch->getSubgroupSize();
932 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
933}
934
935// This function returns all layouts for the given sgCount, whose sgData:
936// 1. Evenly divides the wgShape.
937// 2. Is a multiple of instData.
938// Example:
939// wgShape = [128, 64], instData = [8, 16], sgCount = 32
940// Returns layouts:
941// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
942using LayoutRepresentation = std::pair<int64_t, int64_t>;
945 int64_t sgCount) {
947 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
948 if (sgCount % sgLayout0)
949 continue;
950 int64_t sgLayout1 = sgCount / sgLayout0;
951 int64_t sgData0 = wgShape[0] / sgLayout0;
952 int64_t sgData1 = wgShape[1] / sgLayout1;
953 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
954 (sgData0 % instData[0] || sgData1 % instData[1]))
955 continue;
956 candidates.emplace_back(sgLayout0, sgLayout1);
957 }
958 // Sort primarily by how balanced they are
959 // (i.e., minimize the absolute difference between the two dimensions), and
960 // secondarily by the first dimension in ascending order.
961 llvm::sort(candidates, [](const LayoutRepresentation &lhs,
962 const LayoutRepresentation &rhs) {
963 int diffLhs = std::abs(lhs.first - lhs.second);
964 int diffRhs = std::abs(rhs.first - rhs.second);
965 if (diffLhs != diffRhs)
966 return diffLhs < diffRhs;
967 return lhs.first < rhs.first;
968 });
969 return candidates;
970}
971
972/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
973/// The numSg and consumerLayout (optional) are only used by sg layout
974/// creation.
975std::optional<
976 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
977 xegpu::DistributeLayoutAttr>>
978xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
979 VectorType bTy, VectorType cdTy,
980 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
981 const xegpu::uArch::uArch *uArch) {
982 auto context = aTy.getContext();
983 const auto *uArchInstruction =
984 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
986
987 auto getInstDataVectors = [&]()
988 -> std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
990 const int subgroupSize = uArch->getSubgroupSize();
991 const unsigned dataALen = aTy.getShape().front();
992 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
993 const int maxALen =
994 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
995
996 const unsigned dataBLen = bTy.getShape().back();
997 auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
998 const int maxBLen =
999 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
1000
1001 auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
1002 const int maxCLen =
1003 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
1004 if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
1005 return std::nullopt;
1006
1007 SmallVector<int64_t> instDataA(aTy.getRank(), 1);
1008 instDataA[aTy.getRank() - 2] = maxALen;
1009 instDataA[aTy.getRank() - 1] = subgroupSize;
1010 SmallVector<int64_t> instDataB(bTy.getRank(), 1);
1011 instDataB[bTy.getRank() - 2] = subgroupSize;
1012 instDataB[bTy.getRank() - 1] = maxBLen;
1013 SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
1014 instDataCD[cdTy.getRank() - 2] = maxALen;
1015 instDataCD[cdTy.getRank() - 1] = maxCLen;
1016 return std::make_tuple(instDataA, instDataB, instDataCD);
1017 };
1018
1019 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1020 assert(numSg > 0 &&
1021 "Number of subgroups must be provided for sg layout creation.");
1022 auto instDataVecs = getInstDataVectors();
1023 if (!instDataVecs)
1024 return std::nullopt;
1025 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1026 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
1027 instDataCD.size() == 2 &&
1028 "Sg layout creation expects valid 2D inst data");
1029
1030 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
1031 if (consumerLayout && consumerLayout.isForWorkgroup()) {
1032 SmallVector<int64_t> sgLayoutD =
1033 consumerLayout.getEffectiveSgLayoutAsInt();
1034 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
1035 }
1036
1037 // Step 1. Get all valid layouts for A, B and C/D operands.
1038 // Order them from most balanced to least balanced.
1039 auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
1040 auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
1041 auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
1042 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1043 return std::nullopt;
1044
1045 // Step 2. If the consumer layout can be reused for all operands, that
1046 // layout is chosen. Otherwise, pick the most balanced subgroup layout
1047 // that is valid for A, B and C (if present) operands
1048 llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
1049 llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
1050 layoutsCD.end());
1051 std::optional<LayoutRepresentation> bestPick;
1052 for (auto &sgLayout : layoutsB) {
1053 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1054 // Is in (A and B and CD) and matches consumer -> best pick
1055 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1056 bestPick = sgLayout;
1057 break;
1058 }
1059 // Is in (A and B and CD) layoutsB is ordered from most
1060 // balanced to least. So the first one we see is the most balanced
1061 // one, remember it and later only update if there is one that matches
1062 // the consumer.
1063 if (!bestPick)
1064 bestPick = sgLayout;
1065 }
1066 }
1067 // Step 3. If there is no subgroup layout compatible with A, B and C (if
1068 // present) operands, we fail.
1069 if (!bestPick)
1070 return std::nullopt;
1071 SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
1072 static_cast<int>(bestPick->second)};
1073 SmallVector<int> sgDataA = {
1074 static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1075 static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
1076 SmallVector<int> sgDataB = {
1077 static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
1078 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
1079 SmallVector<int> sgDataCD = {
1080 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1081 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1082
1083 auto dpasALayout = xegpu::LayoutAttr::get(
1084 context, DenseI32ArrayAttr::get(context, sgLayout),
1085 DenseI32ArrayAttr::get(context, sgDataA),
1086 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
1087 /*lane_data =*/nullptr, /*order =*/nullptr);
1088
1089 auto dpasBLayout = xegpu::LayoutAttr::get(
1090 context, DenseI32ArrayAttr::get(context, sgLayout),
1091 DenseI32ArrayAttr::get(context, sgDataB),
1092 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
1093 /*lane_data =*/nullptr, /*order =*/nullptr);
1094
1095 auto dpasCDLayout = xegpu::LayoutAttr::get(
1096 context, DenseI32ArrayAttr::get(context, sgLayout),
1097 DenseI32ArrayAttr::get(context, sgDataCD),
1098 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
1099 /*lane_data =*/nullptr, /*order =*/nullptr);
1100 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
1101 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1102 auto instDataVecs = getInstDataVectors();
1103 if (!instDataVecs)
1104 return std::nullopt;
1105 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1106 return std::make_tuple(
1107 xegpu::LayoutAttr::get(
1108 context, SmallVector<int>(instDataA.begin(), instDataA.end())),
1109 xegpu::LayoutAttr::get(
1110 context, SmallVector<int>(instDataB.begin(), instDataB.end())),
1111 xegpu::LayoutAttr::get(
1112 context, SmallVector<int>(instDataCD.begin(), instDataCD.end())));
1113 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1114 auto aLayout = getDefaultLaneLayout2DBlockIo(
1115 aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
1116 auto bLayout = getDefaultLaneLayout2DBlockIo(
1117 bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
1118 auto cdLayout = getDefaultLaneLayout2DBlockIo(
1119 cdTy, uArch /*, packingSize = std::nullopt */);
1120 return std::make_tuple(aLayout, bLayout, cdLayout);
1121 }
1122 return std::nullopt;
1123}
1124
1125xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
1126 Operation *op = operand.getOwner();
1127 unsigned idx = operand.getOperandNumber();
1128 xegpu::DistributeLayoutAttr resLayout;
1129 if (op->getNumResults() == 1)
1130 resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1131
1132 // For vector::BroadcastOp, infer the source layout from the result layout.
1133 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
1134 if (!resLayout)
1135 return xegpu::DistributeLayoutAttr();
1136 auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
1137 if (!srcTy)
1138 return xegpu::DistributeLayoutAttr();
1140 resLayout, broadcast.getResultVectorType().getShape(),
1141 srcTy.getShape());
1142 }
1143
1144 // For vector::MultiDimReductionOp, infer source layout from result layout
1145 // using reduction dims. Acc operand is expected to have the same layout as
1146 // the result.
1147 if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
1148 if (!resLayout)
1149 return xegpu::DistributeLayoutAttr();
1150 if (idx == 0) {
1151 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
1152 return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
1153 }
1154 if (idx == 1)
1155 return resLayout;
1156 }
1157
1158 if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
1159 if (!resLayout)
1160 return xegpu::DistributeLayoutAttr();
1161 return xegpu::inferReductionSourceLayout(resLayout);
1162 }
1163
1164 // For vector::BitCastOp, infer source layout from result layout using
1165 // element type bitwidths.
1166 if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1167 if (!resLayout)
1168 return xegpu::DistributeLayoutAttr();
1169 int resElemBitWidth =
1170 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1171 int srcElemBitWidth =
1172 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1173 return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
1174 srcElemBitWidth);
1175 }
1176
1177 // For vector::ShapeCastOp, infer source layout from result layout using
1178 // shapes.
1179 if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
1180 if (!resLayout)
1181 return xegpu::DistributeLayoutAttr();
1183 resLayout, shapeCast.getResultVectorType().getShape(),
1184 shapeCast.getSourceVectorType().getShape());
1185 }
1186
1187 // For vector::InsertStridedSliceOp, infer source layout from result layout.
1188 // Dest vector must have the same layout as the result.
1189 if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1190 if (!resLayout)
1191 return xegpu::DistributeLayoutAttr();
1192 if (idx == 0)
1194 resLayout, insertSlice.getDestVectorType().getShape(),
1195 insertSlice.getSourceVectorType().getShape());
1196 if (idx == 1)
1197 return resLayout;
1198 }
1199
1200 // For vector::TransposeOp, infer source layout from result layout using
1201 // permutation.
1202 if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
1203 if (!resLayout)
1204 return xegpu::DistributeLayoutAttr();
1205 return xegpu::inferTransposeSourceLayout(resLayout,
1206 transpose.getPermutation());
1207 }
1208
1209 // For elementwise operations, all operands must have the same layout as the
1210 // result.
1212 if (!resLayout)
1213 return xegpu::DistributeLayoutAttr();
1214 return resLayout;
1215 }
1216 // TODO: Handle more cases as needed here.
1217 // By default, assume no layout conflict and return the current layout of
1218 // the operand.
1219 return xegpu::getDistributeLayoutAttr(operand.get());
1220}
lhs
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
std::pair< int64_t, int64_t > LayoutRepresentation
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, bool isChunkedStore, int maxChunkSize, ArrayRef< int64_t > srcShape, int subgroupSize)
Sets up the anchor layout for store scatter and store matrix operation.
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch, std::optional< unsigned > packingSize=std::nullopt, bool vnni=false)
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad, int maxChunkSize, ArrayRef< int64_t > resShape, int subgroupSize)
Sets up the anchor layout for load gather and load matrix operation.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasAttrOfType(NameT &&name)
Definition Operation.h:601
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getOpResults()
Definition Operation.h:446
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:626
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
static WalkResult advance()
Definition WalkResult.h:47
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
void recoverTemporaryLayoutsDeprecated(Operation *op)
[to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and OpResult of of the given opera...
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:151
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:163