MLIR 23.0.0git
XeGPULayoutImpl.cpp
Go to the documentation of this file.
1//===---- XeGPULayoutImpl.cpp - MLIR Utilities for XeGPUOps
2//------------------===//
3//
4// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements layout utility functions for XeGPU dialect
11// transformation.
12//
13//===----------------------------------------------------------------------===//
14
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/ValueRange.h"
30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/Support/FormatVariadic.h"
32#include <cstdint>
33#include <numeric>
34
35using namespace mlir;
36
40 out.reserve(attrs.size());
41
42 for (auto attr : attrs) {
43 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
44 auto newLayout = dist.dropSgLayoutAndData();
45 if (newLayout)
46 out.emplace_back(attr.getName(), newLayout);
47 } else {
48 out.push_back(attr);
49 }
50 }
51
52 return out;
53}
54
58 out.reserve(attrs.size());
59
60 for (auto attr : attrs) {
61 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
62 auto newLayout = dist.dropInstData();
63 if (newLayout)
64 out.emplace_back(attr.getName(), newLayout);
65 } else {
66 out.push_back(attr);
67 }
68 }
69
70 return out;
71}
72
73// Sets the layout on a TensorDesc value by updating its type to include
74// the given layout, if the type does not already have a layout attached.
75static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout) {
76 auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(val.getType());
77 if (!tensorDescTy || tensorDescTy.getLayoutAttr())
78 return;
79 auto typeWithLayout = xegpu::TensorDescType::get(
80 tensorDescTy.getContext(), tensorDescTy.getShape(),
81 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
82 val.setType(typeWithLayout);
83}
84
// walkRegionBackward() recursively walks `region` in reverse order: blocks are
// visited in post-order (so use blocks come before def blocks), the ops of each
// block from back to front, and every op's nested regions are walked before
// `visit` is invoked on the op itself.
88static void walkRegionBackward(Region &region,
90
91 // Use post-order traversal to process blocks in reverse topological order.
92 // This ensures that use blocks are visited before def blocks, which is
93 // required for backward layout propagation.
94 if (region.empty())
95 return;
96 llvm::ReversePostOrderTraversal<Region *> rpot(&region);
97 SmallVector<Block *> blocks(rpot.begin(), rpot.end());
98 for (Block *block : llvm::reverse(blocks)) {
99 // ops: back -> front
100 for (Operation &op : llvm::reverse(*block)) {
101 // make sure we first visit inside the region op (so yield op first)
102 // and then move to region op itself
103 // Regions are iterated in forward order so that for multi-region ops
104 // like scf.while, earlier regions (e.g., "before/cond") are processed
105 // first. This ensures that when a later region's terminator (e.g., "do"
106 // yield) needs the layout of an earlier region's block args, those
107 // layouts are already available from use points.
108 for (Region &nested : op.getRegions())
109 walkRegionBackward(nested, visit);
110
111 visit(&op);
112 }
113 }
114}
115
116static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
117 xegpu::DistributeLayoutAttr layout = nullptr;
118 for (OpOperand &use : result.getUses()) {
119 if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
120 if (!layout)
121 layout = tmpLayout;
122 break;
123 }
124 }
125 return layout;
126}
127
128// Returns true if `op` is safe and cheap to clone (no side effects, no
129// regions, and all operands are themselves trivially rematerializable, e.g.
130// block-arg-free pure value generators such as `vector.step`, splat
131// `arith.constant`, or `vector.create_mask` whose operands are constants).
133 if (!op || op->getNumRegions() != 0)
134 return false;
135 if (!isMemoryEffectFree(op))
136 return false;
137 for (Value v : op->getOperands()) {
138 Operation *defOp = v.getDefiningOp();
139 if (!defOp)
140 return false;
141 if (!isTriviallyRematerializable(defOp))
142 return false;
143 }
144 return true;
145}
146
147// For regular operations: First the result layouts are propagated from uses.
148// Then the result layouts are propagated to uses (operands).
150 if (op->getNumResults() == 0)
151 return;
152 if (op->getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
153 return;
154 OpResult result = op->getResult(0);
155 xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
156 Type resultType = result.getType();
157
158 if (!resLayout)
159 return;
160
161 // Recover layout for TensorDesc type results by updating the type to include
162 // the layout. For vector type
163 if (isa<xegpu::TensorDescType>(resultType))
164 setTensorDescLayout(result, resLayout);
165
166 // Recover layout for vector type results, or for multi-reduction ops which
167 // may reduce to a scalar that still needs a layout.
168 if (isa<VectorType>(resultType) || isa<vector::MultiDimReductionOp>(op))
170
171 if (isa<vector::DeinterleaveOp>(op))
172 xegpu::setTemporaryLayout(op->getResult(1), resLayout);
173
174 for (OpOperand &opr : op->getOpOperands()) {
175 xegpu::DistributeLayoutAttr operandLayout =
177 if (isa<VectorType>(opr.get().getType()) && operandLayout)
178 xegpu::setTemporaryLayout(opr, operandLayout);
179 }
180}
181
182// Propagate layout from region op results and sibling region block args
183// to yield/condition operands. For each successor of this terminator:
184// - Parent successor: propagate from parent op's result layouts (use points).
185// - Region successor: propagate from target region's block arg layouts (use
186// points), e.g., scf.yield in "after/do" region propagates to "before/cond"
187// block args.
189 mlir::RegionBranchTerminatorOpInterface yieldOp) {
190 auto regionBranchOp =
191 dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
192 if (!regionBranchOp)
193 return;
194
196 SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
197 yieldOp.getSuccessorRegions(operandAttrs, successors);
198
199 for (const RegionSuccessor &successor : successors) {
200 OperandRange succOps = yieldOp.getSuccessorOperands(successor);
201 if (succOps.empty())
202 continue;
203 unsigned beginIdx = succOps.getBeginOperandIndex();
204 ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
205 unsigned count = std::min<unsigned>(succOps.size(), successorInputs.size());
206
207 for (unsigned i = 0; i < count; ++i) {
208 xegpu::DistributeLayoutAttr layout;
209 if (successor.isParent()) {
210 // For parent successor, get layout from external use points of the
211 // parent op's results.
212 auto regionResult = regionBranchOp->getResult(i);
213 layout = getLayoutFromUsePoints(regionResult);
214 if (layout) {
215 // set layout for the region op, like scf.loop
216 xegpu::setTemporaryLayout(regionResult, layout);
217 if (isa<xegpu::TensorDescType>(regionResult.getType()))
218 setTensorDescLayout(regionResult, layout);
219 }
220 } else {
221 // For region successor, get layout from the target region's block
222 // arg use points (e.g., "before/cond" region args for scf.while
223 // "after/do" yield).
224 layout = getLayoutFromUsePoints(successorInputs[i]);
225 }
226 if (!layout)
227 continue;
228 auto operandType = succOps[i].getType();
229 if (isa<VectorType>(operandType) ||
230 dyn_cast<xegpu::TensorDescType>(operandType))
231 // recover layout for yield op operands
232 xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i), layout);
233 }
234 }
235}
236
237// Propagate layout from region arguments to region op's init operands. This
238// sets the temporary layout for region arguments and init operands.
239static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
240 // Iterate all regions of the region op. For each block argument that has a
241 // layout (determined from its use points), trace back to find the
242 // corresponding init operand of the regionOp and set the layout on it.
243 // This works generically for scf.for, scf.while, and other
244 // RegionBranchOpInterface ops.
245 for (Region &region : regionOp->getRegions()) {
246 RegionSuccessor regionSuccessor(&region);
247 // Use getSuccessorInputs to get the block arguments that correspond to
248 // predecessor operands. This correctly handles ops like scf.for where
249 // the induction variable is a block arg but not a successor input.
250 ValueRange successorInputs = regionOp.getSuccessorInputs(regionSuccessor);
251 for (auto [inputIdx, regionArg] : llvm::enumerate(successorInputs)) {
252 auto layout = getLayoutFromUsePoints(regionArg);
253 if (!layout)
254 continue;
255
256 // Recover layout for tensor_desc block args by updating the type.
257 if (isa<xegpu::TensorDescType>(regionArg.getType()))
258 setTensorDescLayout(regionArg, layout);
259
260 // Recover layout for region op operands, like scf.for's init operands.
261 // Find all predecessor values that flow into this block argument.
262 SmallVector<Value> predValues;
263 regionOp.getPredecessorValues(regionSuccessor, inputIdx, predValues);
264 for (Value predVal : predValues) {
265 // Match predecessor value to an operand of the regionOp.
266 for (OpOperand &operand : regionOp->getOpOperands()) {
267 if (operand.get() == predVal)
268 xegpu::setTemporaryLayout(operand, layout);
269 }
270 }
271 }
272 }
273}
274
275// Prerequisite for Layout Recovery
// It relies on the following invariants:
277// 1. there is no layout conflict between different uses of the same definition.
278// 2. each definition has a well-defined layout requirement at its use point.
279// - Every definition must have at least one use that appears after it in
280// topological order.
281// - TODO: If a definition has no such use (e.g., a loop result or region
282// output), an explicit convert_layout operation is inserted to create a
283// use.
284// - Only the result of convert_layout is permitted to have no subsequent
285// use.
286//
// The recovery proceeds by scanning the operations in reverse topological order
288// as follows:
289// For regular operations: First the result layouts are propagated from uses.
290// Then the result layouts are propagated to operands.
291//
292// For region operations (e.g., loops):
293// - When backward propagation reaches a region op, it sets the layout of
294// the region op’s results according to use points like regular ops.
295// - Then, the result layouts (such as a loop output) are propagated to
296// their corresponding operands in the yield.
297// - When backward propagation reaches the first operation inside the
298// region, the pass examines the region op’s initialization list,
299// propagating from region arguments to the corresponding initialization
300// operands.
301// - This ensures that layouts are consistently propagated
302// across region boundaries while preserving a single well-defined use for
303// each definition at the region-op level.
305 auto processFunc = [&](Region &body, StringRef funcName) {
306 walkRegionBackward(body, [&](Operation *op) {
307 if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
309 } else if (auto yieldOp =
310 dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
312 } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
314 }
315 });
316 };
318 rootOp->walk([&](func::FuncOp func) {
319 processFunc(func.getBody(), func.getSymName());
320 });
321 rootOp->walk([&](gpu::GPUFuncOp func) {
322 processFunc(func.getBody(), func.getName());
323 });
324
325 return true;
326}
327
328template <typename T, typename>
329void xegpu::removeLayoutAttr(const T &operandOrResult) {
330 Operation *owner = operandOrResult.getOwner();
331 std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
332 if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
333 owner->removeAttr(name);
334}
335
336// Explicit instantiation for OpResult
337template void
339
340// Explicit instantiation for OpOperand
341template void
343
345 op->walk([&](Operation *nestOp) {
346 // Remove all attributes of DistributeLayoutAttr type
347 SmallVector<StringAttr> attrsToRemove;
348 for (auto namedAttr : nestOp->getAttrs()) {
349 if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
350 attrsToRemove.push_back(namedAttr.getName());
351 }
352 for (auto attrName : attrsToRemove)
353 nestOp->removeAttr(attrName);
354 });
355}
356
358 op->walk([&](Operation *nestOp) {
359 SmallVector<StringAttr> attrsToRemove;
360 for (auto namedAttr : nestOp->getDiscardableAttrs()) {
361 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
362 attrsToRemove.push_back(namedAttr.getName());
363 }
364 for (auto attrName : attrsToRemove)
365 nestOp->removeDiscardableAttr(attrName);
366 });
367}
368
369/// Infers the source layout attribute for a broadcast operation given the
370/// result layout attribute, result shape, source shape.
371xegpu::DistributeLayoutAttr
372xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
373 ArrayRef<int64_t> resShape,
374 ArrayRef<int64_t> srcShape) {
375
376 SmallVector<int64_t> bcastDims;
377 size_t dimDiff = resShape.size() - srcShape.size();
378 auto bcastSourceLayout = resLayout;
379 for (size_t i = dimDiff; i < resShape.size(); i++) {
380 if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
381 bcastDims.push_back(i);
382 }
383
384 // the sg_layout and lane_layout for unit dimensions are preserved so it can
385 // be propagate to producer op so potentially used by the multi-reduction op.
386 if (!bcastDims.empty())
387 bcastSourceLayout = bcastSourceLayout.setUnitDimData(bcastDims);
388
389 if (dimDiff > 0) {
390 SmallVector<int64_t> sliceDims;
391 for (size_t i = 0; i < dimDiff; i++)
392 sliceDims.push_back(i);
393 bcastSourceLayout = xegpu::SliceAttr::get(
394 resLayout.getContext(), bcastSourceLayout,
395 DenseI64ArrayAttr::get(resLayout.getContext(), sliceDims));
396 }
397 return bcastSourceLayout;
398}
399
400/// Infers the source layout attribute for a reduction operation given the
401/// result layout attribute and reduced dims.
402xegpu::DistributeLayoutAttr
403xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
404 SmallVector<int64_t> reduceDims) {
405
406 assert(isa<xegpu::SliceAttr>(resLayout) &&
407 "reduction result layout must be slice layout");
408
409 xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
410
411 assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
412 "reduction dims must match with slice dims");
413
414 return sliceLayout.getParent();
415}
416
417xegpu::DistributeLayoutAttr
418xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
419 return xegpu::inferMultiReductionSourceLayout(resLayout, {0});
420}
421
422/// Infers the source layout attribute for a transpose operation given the
423/// result layout attribute and permutation.
424xegpu::DistributeLayoutAttr
425xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
426 ArrayRef<int64_t> permutation) {
427 return resLayout.transposeDims(permutation);
428}
429
430/// Infers the source layout attribute for a bitcast operation given the
431/// result layout attribute, result element type bitwidth, and source element
432/// type bitwidth.
433xegpu::DistributeLayoutAttr
434xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
435 int resElemTyBitWidth, int srcElemTyBitWidth) {
436
437 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
438 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
439 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
440 size_t sgDataSize = sgData.size();
441 size_t instDataSize = instData.size();
442 size_t laneDataSize = laneData.size();
443 int64_t sgDataValue = -1;
444 int64_t instDataValue = -1;
445 int64_t laneDataValue = -1;
446 int64_t dim = resLayout.getRank() - 1;
447
448 if (srcElemTyBitWidth <= resElemTyBitWidth) {
449 int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
450 if (sgDataSize)
451 sgDataValue = sgData.back() * bitWidthRatio;
452 if (instDataSize)
453 instDataValue = instData.back() * bitWidthRatio;
454 if (laneDataSize)
455 laneDataValue = laneData.back() * bitWidthRatio;
456 } else {
457 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
458 if (sgDataSize) {
459 assert((sgData.back() % bitWidthRatio) == 0 &&
460 "sgData not divisible by bitWidthRatio");
461 sgDataValue = sgData.back() / bitWidthRatio;
462 }
463 if (instDataSize) {
464 assert((instData.back() % bitWidthRatio) == 0 &&
465 "instData not divisible by bitWidthRatio");
466 instDataValue = instData.back() / bitWidthRatio;
467 }
468 if (laneDataSize) {
469 assert((laneData.back() % bitWidthRatio) == 0 &&
470 "laneData not divisible by bitWidthRatio");
471 laneDataValue = laneData.back() / bitWidthRatio;
472 }
473 }
474
475 xegpu::DistributeLayoutAttr finalSrcLayout;
476 finalSrcLayout =
477 resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
478
479 return finalSrcLayout;
480}
481
482/// Infers the source layout attribute for an interleave operation given the
483/// result layout attribute. Interleave doubles the size of the innermost
484/// dimension, so the layout inference is similar to bitcast where the source
485/// element type is larger than the result element type (ratio = 2).
486xegpu::DistributeLayoutAttr
487xegpu::inferInterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
488
489 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
490 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
491 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
492 size_t sgDataSize = sgData.size();
493 size_t instDataSize = instData.size();
494 size_t laneDataSize = laneData.size();
495 int64_t sgDataValue = -1;
496 int64_t instDataValue = -1;
497 int64_t laneDataValue = -1;
498 int64_t dim = resLayout.getRank() - 1;
499
500 // Interleave doubles the innermost dimension, so we need to halve the
501 // layout values (similar to bitcast with ratio = 2)
502 constexpr int ratio = 2;
503 if (sgDataSize) {
504 assert((sgData.back() % ratio) == 0 &&
505 "sgData not divisible by interleave ratio");
506 sgDataValue = sgData.back() / ratio;
507 }
508 if (instDataSize) {
509 assert((instData.back() % ratio) == 0 &&
510 "instData not divisible by interleave ratio");
511 instDataValue = instData.back() / ratio;
512 }
513 if (laneDataSize) {
514 assert((laneData.back() % ratio) == 0 &&
515 "laneData not divisible by interleave ratio");
516 laneDataValue = laneData.back() / ratio;
517 }
518
519 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
520}
521
522/// Infers the source layout attribute for a deinterleave operation given the
523/// result layout attribute. Deinterleave halves the size of the innermost
524/// dimension, so the layout inference is similar to bitcast where the source
525/// element type is smaller than the result element type (ratio = 2).
526xegpu::DistributeLayoutAttr
527xegpu::inferDeinterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
528
529 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
530 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
531 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
532 size_t sgDataSize = sgData.size();
533 size_t instDataSize = instData.size();
534 size_t laneDataSize = laneData.size();
535 int64_t sgDataValue = -1;
536 int64_t instDataValue = -1;
537 int64_t laneDataValue = -1;
538 int64_t dim = resLayout.getRank() - 1;
539
540 // Deinterleave halves the innermost dimension, so we need to double the
541 // layout values (similar to bitcast with ratio = 2)
542 constexpr int ratio = 2;
543 if (sgDataSize)
544 sgDataValue = sgData.back() * ratio;
545 if (instDataSize)
546 instDataValue = instData.back() * ratio;
547 if (laneDataSize)
548 laneDataValue = laneData.back() * ratio;
549
550 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
551}
552
553/// Infers the source layout attribute for an insert strided slice operation
554/// given the result layout attribute, result shape, and source shape. Removes
555/// leading dimensions from the result layout to match the source shape size.
556xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
557 xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
558 ArrayRef<int64_t> srcShape) {
559
560 int srcShapeSize = srcShape.size();
561 int resShapeSize = resShape.size();
562 int dimDiff = resShapeSize - srcShapeSize;
563
564 if (dimDiff > 0) {
565 // assert that the leading dimensions being sliced off are not distributed
566 // (i.e. sg_layout and lane_layout for those dimensions are all 1)
567 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
568 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
569 for (int i = 0; i < dimDiff; i++) {
570 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
571 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
572 "Leading dimensions being sliced off must not be distributed");
573 }
574 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
575 }
576 return resLayout;
577}
578
579/// Infers the source layout attribute for an insert operation
580/// given the result layout attribute, result shape, and source shape. Removes
581/// leading dimensions from the result layout to match the source shape size.
582// TODO: add propagation support for insert op
583xegpu::DistributeLayoutAttr
584xegpu::inferInsertSourceLayout(xegpu::DistributeLayoutAttr resLayout,
585 ArrayRef<int64_t> resShape,
586 ArrayRef<int64_t> srcShape) {
587
588 int srcShapeSize = srcShape.size();
589 int resShapeSize = resShape.size();
590 int dimDiff = resShapeSize - srcShapeSize;
591
592 if (dimDiff > 0) {
593 // assert that the leading dimensions being sliced off are not distributed
594 // (i.e. sg_layout and lane_layout for those dimensions are all 1)
595 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
596 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
597 for (int i = 0; i < dimDiff; i++) {
598 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
599 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
600 "Leading dimensions being sliced off must not be distributed");
601 }
602 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
603 }
604 return resLayout;
605}
606
607/// Infers the source layout attribute for extract operation
608/// given the result layout attribute, result shape, and source shape. Adds
609/// leading dimensions to the source layout to match the source shape size.
610// TODO: add layout attribute interface: expandDims() and use it here.
611// TODO: add propagation support for extract op
612xegpu::DistributeLayoutAttr
613xegpu::inferExtractSourceLayout(xegpu::DistributeLayoutAttr resLayout,
614 ArrayRef<int64_t> resShape,
615 ArrayRef<int64_t> srcShape) {
616
617 int srcShapeSize = srcShape.size();
618 int resShapeSize = resShape.size();
619 int dimDiff = srcShapeSize - resShapeSize;
620 auto context = resLayout.getContext();
621 // construct the source layout by adding unit dimensions to the front of
622 // result layout
623 if (dimDiff > 0) {
624 auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
625 auto sgData = resLayout.getEffectiveSgDataAsInt();
626 auto instData = resLayout.getEffectiveInstDataAsInt();
627 auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
628 auto laneData = resLayout.getEffectiveLaneDataAsInt();
629 auto order = resLayout.getEffectiveOrderAsInt();
630
631 // Example: result shape is 3D with order [1, 2, 0], source shape is 5D
632 // (adding 2 leading dimensions). Expected source order: [3, 4, 2, 1, 0]
633 // Step 1: shift existing order by dimDiff: [1, 2, 0] -> [3, 4, 2]
634 // Step 2: append new leading dims in reverse (slowest first): [3, 4, 2, 1,
635 // 0]
636
637 // Shift existing dimension indices in order by dimDiff to account for the
638 // new leading dimensions being added to the source shape
639 for (auto &o : order)
640 o += dimDiff;
641
642 // Add unit dimensions to the front of non-empty layout vectors and append
643 // the new dimension indices to the order array in reverse (slowest
644 // dimension has the lowest index and appears last in the order array)
645 for (int i = 0; i < dimDiff; i++) {
646 if (!sgLayout.empty())
647 sgLayout.insert(sgLayout.begin(), 1);
648 if (!sgData.empty())
649 sgData.insert(sgData.begin(), 1);
650 if (!instData.empty())
651 instData.insert(instData.begin(), 1);
652 if (!laneLayout.empty())
653 laneLayout.insert(laneLayout.begin(), 1);
654 if (!laneData.empty())
655 laneData.insert(laneData.begin(), 1);
656 order.push_back(dimDiff - 1 - i);
657 }
658
659 DenseI32ArrayAttr orderAttr = resLayout ? resLayout.getOrder() : nullptr;
660 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
661 if (v.empty())
662 return DenseI32ArrayAttr();
663 SmallVector<int32_t> v32(v.begin(), v.end());
664 return DenseI32ArrayAttr::get(context, v32);
665 };
666 auto srcLayout = xegpu::LayoutAttr::get(
667 context, sgLayout.empty() ? nullptr : toAttr(sgLayout),
668 sgData.empty() ? nullptr : toAttr(sgData),
669 instData.empty() ? nullptr : toAttr(instData),
670 laneLayout.empty() ? nullptr : toAttr(laneLayout),
671 laneData.empty() ? nullptr : toAttr(laneData),
672 (!orderAttr || orderAttr.empty()) ? nullptr : toAttr(order));
673 return srcLayout;
674 }
675 return resLayout;
676}
677
678/// Infers the source layout attribute for a shape cast operation given the
679/// result layout attribute, result shape, and source shape.
680xegpu::DistributeLayoutAttr
681xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
682 ArrayRef<int64_t> resShape,
683 ArrayRef<int64_t> srcShape) {
684
685 // There are three use cases:
686 // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
687 // tensor before broadcast
688 // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
689 // for multi-stage reduction
690 // 3. combines all dims to a single dim and put in the innermost dim in 2d as
691 // [1, combinedData] or [combinedData]. Say, [2, 4, 8] -> [1, 64] or [64]
692 // Use cases are only supported after workgroup distribution,
693 // like cross-sg reduction saves multidimension data to
694 // 1D slm buffer, shapecast inserted by cse/canonicalization passes.
695
696 // Use case 1: Shapes only differ by expanding unit dimensions, for broadcast
697 SmallVector<int64_t> expandedUnitDims;
698
699 if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
700 // create a slice layout for the source by removing the expanded unit dims
701 auto sliceDimsAttr = DenseI64ArrayAttr::get(
702 resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
703 auto srcLayout =
704 xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
705 return srcLayout;
706 }
707
708 // Use case 2: Dim split from source to result, for multi-stage reduction
709 SmallVector<SmallVector<int64_t>> splitDimGroups;
710 if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups)) {
711 auto srcLayout = resLayout;
712 for (const auto &dimGroup : splitDimGroups)
713 srcLayout = srcLayout.collapseDims(dimGroup);
714
715 return srcLayout;
716 }
717
718 // Use case 3: Collaspse to innermost dim, for cross-sg reduction to SLM
719 auto matchCollapseToInnermostDim = [&](ArrayRef<int64_t> src,
720 ArrayRef<int64_t> dst) -> bool {
721 // only one non-unit dim in dst which is the innermost dim
722 if ((dst.size() != 2) && (dst.size() != 1))
723 return false;
724 int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
725 std::multiplies<int64_t>());
726 if (dst.size() == 1)
727 return (dst[0] == srcSize);
728 return (dst[0] == 1) && (dst[1] == srcSize);
729 };
730
731 if (matchCollapseToInnermostDim(srcShape, resShape)) {
732 int srcShapeSize = srcShape.size();
733 int resShapeSize = resShape.size();
734 auto context = resLayout.getContext();
735 auto resInstData = resLayout.getEffectiveInstDataAsInt();
736 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
737 auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
738
739 // Extract layout info from result's innermost dimension and apply to
740 // source's innermost dimension while setting all other dimensions to 1.
741 // The inferred layout is restricted by srcShape to ensure it fits within
742 // the source dimensions.
743 // Examples 1:
744 // srcShape=[8, 16, 32], resShape=[1, 4096]
745 // resInstData=[1, 16]
746 // -> inferredInstData=[1, 1, min(16, 32)]=[1, 1, 16]
747 // Examples 2:
748 // srcShape=[4, 8, 64], resShape=[2048]
749 // resLaneLayout=[16], resLaneData=[2]
750 // -> inferredLaneLayout=[1, 1, 16]
751 // -> inferredLaneData=[1, 1, min(2, 64/16)]=[1, 1, 2]
752
753 if (resInstData.size() != 0) {
754 // assert resInstData must be 1 for all but the innermost dim
755 for (int i = 0; i < resShapeSize - 1; i++) {
756 assert(resInstData[i] == 1 &&
757 "only innermost dim can have non-unit instData");
758 }
759 SmallVector<int> inferredInstData(srcShapeSize, 1);
760 inferredInstData[srcShapeSize - 1] =
761 std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
762 return xegpu::LayoutAttr::get(context, inferredInstData);
763 }
764
765 if (resLaneLayout.size() != 0) {
766 for (int i = 0; i < resShapeSize - 1; i++) {
767 assert(resLaneData[i] == 1 &&
768 "only innermost dim can have non-unit instData");
769 }
770 assert(srcShape.back() % resLaneLayout.back() == 0 &&
771 "source innermost dim must be >= result lane layout");
772 SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
773 SmallVector<int> inferredLaneData(srcShapeSize, 1);
774 inferredLaneLayout.back() = resLaneLayout.back();
775 inferredLaneData.back() = std::min(
776 resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
777 return xegpu::LayoutAttr::get(context, inferredLaneLayout,
778 inferredLaneData);
779 }
780 }
781 llvm_unreachable("running into unsupported shape cast scenarios");
782 return nullptr;
783}
784
785/// Infers the layout attribute for mask and offset operand for Chunked load
786/// and store, given the anchor layout attribute for the value being load/store.
787xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
788 xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
789 auto rank = payloadLayout.getRank();
790 if (chunkSize > 1)
791 return payloadLayout.dropDims(
792 llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
793 return payloadLayout;
794}
795
796/// Sets up layout for reduction operations by creating a SliceAttr for the
797/// result.
798///
799/// Algorithm Overview:
800/// This function attempts to construct a source layout that, when sliced along
801/// reduction dimensions, produces a result layout compatible with the
802/// consumer layout.
803///
804/// For subgroup layouts, it first tries to align the source layout's subgroup
805/// layout and data with the consumer's layout on non-reduction dimensions.
806/// Then, it distributes remaining subgroups across reduction dimensions. This
807/// avoids subgroup data redistribution overhead between the reduced result and
808/// its consumer. When the consumer layout is a slice layout, it attempts to
809/// reuse the slice layout's parent layout for the source to further minimize
810/// potential data redistribution.
811///
/// InstData requires {1, ..., min(maxReduceVectorSize, srcShape), subgroupSize}
813/// Lane Layout requires {1, ..., 1, subgroupSize}
814/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
815///
816/// Examples:
817/// 1. Subgroup layout - Row reduction on 2D tensor:
818/// srcShape=[32, 128], reductionDims=[1], resShape=[32], subgroupSize=16,
819/// NumSg=32
820/// * Consumer Layout:
821/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>, dims =
822/// [1]>}
823/// * Result Layout:
824/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8],sg_data=[8, 16]>, dims =
825/// [1]>}
826/// Note that the sg_layout is reused but sg_data needs to be adjusted to
827/// evenly distribute the source tensor tile among the reduction dim.
828///
829/// 2. Subgroup layout - Same example above but consumer doesn't have a
830/// reusable slice layout.
831/// * Consumer Layout:
832/// #xegpu.layout<sgLayout=[32], sgData=[1]>
833/// * Result Layout:
834/// #xegpu.slice<#xegpu.layout<sgLayout=[32,1], sgData=[1, 128]>, dims =
835/// [1]>}
836/// * Consumer Layout:
837/// #xegpu.slice<#xegpu.layout<sgLayout=[8, 2, 4], sgData=[4, 64, 32]>,
838/// dims = [1, 2]>}
839/// * Result Layout:
840/// #xegpu.slice<#xegpu.layout<sgLayout=[8,4], sgData=[4, 32]>, dims =
841/// [1]>}
842/// Note that the consumer's layout can't be directly reused as is.
843/// So the algorithm distributes all subgroups on non reduction dimensions
844/// first and then distribute remaining subgroups on the reduction
845/// dimension.
846///
847/// 3. InstData layout - Column reduction:
848/// srcShape=[32, 64], reductionDims=[0], subgroupSize=16
849/// Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
850/// innermost)
851///
852/// 4. Lane layout - Multi-dimensional reduction:
853/// srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
854/// Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
855/// (subgroupSize on innermost dim, maxReduceVectorSize on dim rank-2)
856
858 xegpu::LayoutKind layoutKind, VectorType srcVecTy,
859 DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
860 int numSg, const xegpu::uArch::uArch *uArch) {
861
862 auto srcShape = srcVecTy.getShape();
863 int srcRank = srcShape.size();
864 auto context = srcVecTy.getContext();
865
866 // Helper lambda to convert int64 vectors to int32 DenseArrayAttr
867 auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
868 SmallVector<int32_t> vec32(vec.begin(), vec.end());
869 return DenseI32ArrayAttr::get(context, vec32);
870 };
871
872 const int subgroupSize = uArch->getSubgroupSize();
873 int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
874 xegpu::DistributeLayoutAttr srcLayout;
// Subgroup: reuse the consumer's slice parent when it slices exactly the
// reduction dims; otherwise rebuild sg_layout/sg_data from the consumer.
875 if (layoutKind == xegpu::LayoutKind::Subgroup) {
876 xegpu::SliceAttr consumerSliceLayout =
877 dyn_cast_if_present<xegpu::SliceAttr>(consumerLayout);
878 if (consumerSliceLayout &&
879 consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
880 srcLayout = consumerSliceLayout.getParent();
881 SmallVector<int64_t> sgLayoutFromConsumer =
882 srcLayout.getEffectiveSgLayoutAsInt();
883 auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
// Re-derive sg_data on the reduction dims so that the source tile is split
// evenly across the reused sg_layout.
// NOTE(review): if computeShapeRatio returns nullopt, the parent layout is
// reused with its original sg_data unchanged. Confirm that this is intended.
884 if (srcSgData)
885 for (int dim = 0; dim < srcRank; dim++) {
886 if (llvm::is_contained(reductionDims, dim))
887 srcLayout =
888 srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
889 }
890 } else {
891 SmallVector<int64_t> consumerSgLayout =
892 consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
894 SmallVector<int64_t> consumerSgData =
895 consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
897 SmallVector<int64_t> consumerOrder =
898 consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
900 DenseI32ArrayAttr orderAttr =
901 consumerLayout ? consumerLayout.getOrder() : nullptr;
902 SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
903 int remainingSgCount =
904 consumerLayout ? consumerLayout.getNumSubgroups() : numSg;
905 int consumerIdx = 0;
906
907 // First pass: Match consumer's layout on non-reduction dimensions
908 for (int i = 0; i < srcRank; i++) {
909 if (!llvm::is_contained(reductionDims, i) &&
910 consumerIdx < static_cast<int>(consumerSgLayout.size())) {
911 sgLayout[i] = consumerSgLayout[consumerIdx];
912 sgData[i] = consumerSgData[consumerIdx];
913 remainingSgCount /= sgLayout[i];
914 order[i] = consumerOrder[consumerIdx];
915 consumerIdx++;
916 }
917 }
918
919 // Second pass: Distribute remaining subgroups across reduction dimensions
920 // the reduction to scalar case is handled only by this loop
921 int64_t remainOrder = consumerSgLayout.size();
922 for (int i = 0; i < srcRank; i++) {
923 if (llvm::is_contained(reductionDims, i)) {
924 sgLayout[i] =
925 std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
926 assert((srcShape[i] % sgLayout[i] == 0) &&
927 "source shape not divisible by sg_layout");
928 sgData[i] = srcShape[i] / sgLayout[i];
929 remainingSgCount /= sgLayout[i];
930 order[i] = remainOrder++;
931 }
932 }
933
// The computed order is attached only when the consumer carried a
// non-empty explicit order attribute.
934 assert(remainingSgCount == 1 && "not all subgroups distributed");
935 srcLayout = xegpu::LayoutAttr::get(
936 context, toInt32Attr(sgLayout), toInt32Attr(sgData),
937 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
938 /*lane_data =*/nullptr, /*order =*/
939 (!orderAttr || orderAttr.empty()) ? nullptr : toInt32Attr(order));
940 }
941 } else if (layoutKind == xegpu::LayoutKind::InstData) {
942
// InstData: {1, ..., min(maxReduceVectorSize, srcShape[rank-2]),
//            min(subgroupSize, srcShape[rank-1])}.
943 SmallVector<int64_t> instData(srcRank, 1);
944 if (srcRank >= 2)
945 instData[srcRank - 2] =
946 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
947 instData[srcRank - 1] =
948 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
949 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
950 } else if (layoutKind == xegpu::LayoutKind::Lane) {
951
// Lane: lane_layout = {1, ..., min(subgroupSize, srcShape[rank-1])},
//       lane_data   = {1, ..., min(maxReduceVectorSize, srcShape[rank-2]), 1}.
952 SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
953 laneLayout[srcRank - 1] =
954 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
955 if (srcRank >= 2)
956 laneData[srcRank - 2] =
957 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
958 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
959 toInt32Attr(laneData));
960 }
961
// The result layout is the source layout sliced along the reduction dims.
962 return xegpu::SliceAttr::get(context, srcLayout,
963 DenseI64ArrayAttr::get(context, reductionDims));
964}
965
966/// Sets up layout for Reduction operations by creating a SliceAttr for the
967/// result.
///
/// Only LayoutKind::Lane builds a source layout here: lane_layout =
/// [min(subgroupSize, srcShape[0])] and lane_data = [1]. Only srcShape[0] is
/// read, so a 1-D source is presumably expected (TODO confirm). The returned
/// SliceAttr slices that layout with DenseI64ArrayAttr::get(context, 0)
/// (presumably dims = [0]).
/// NOTE(review): `assert(true && ...)` can never fire. For Subgroup and
/// InstData this therefore silently returns a SliceAttr whose parent layout is
/// null; consider assert(false && ...), llvm_unreachable, or an explicit error.
968xegpu::SliceAttr
970 VectorType srcVecTy,
971 const xegpu::uArch::uArch *uArch) {
972
973 auto srcShape = srcVecTy.getShape();
974 auto context = srcVecTy.getContext();
975 auto subgroupSize = uArch->getSubgroupSize();
976 xegpu::LayoutAttr srcLayout;
977
978 if (layoutKind == xegpu::LayoutKind::Subgroup) {
979 assert(true && "subgroup layout assignment not supported for reduction (op "
980 "is not expected at this level).");
981 } else if (layoutKind == xegpu::LayoutKind::InstData) {
982 assert(true && "instData layout assignment not supported for reduction (op "
983 "is not expected at this level).");
984 } else if (layoutKind == xegpu::LayoutKind::Lane) {
985 SmallVector<int32_t> laneLayout(1), laneData(1);
986 laneLayout[0] = std::min(subgroupSize, static_cast<int32_t>(srcShape[0]));
987 laneData[0] = 1;
988 srcLayout = xegpu::LayoutAttr::get(
989 context, DenseI32ArrayAttr::get(context, laneLayout),
990 DenseI32ArrayAttr::get(context, laneData));
991 }
992
993 auto result = xegpu::SliceAttr::get(context, srcLayout,
994 DenseI64ArrayAttr::get(context, 0));
995 return result;
996}
997
998/// Sets up the result layout for a bitcast operation.
999/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
1000/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
1001/// result layout can be correctly divided back to the source layout during
1002/// inference.
1003///
1004/// Examples:
1005/// 1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
1006/// Consumer layout: instData=[1, 16], subgroupSize=16
1007/// Source shape: [8, 32]
1008/// Result layout: instData=[1, 32] (16 * 2)
1009/// The innermost dimension is multiplied by 2 to maintain consistency.
1010///
1011/// 2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
1012/// Consumer instData=[1, 16], subgroupSize=16
1013/// Source shape: [4, 128]
1014/// adjust the instData from [1, 16] to [1, 16 * 4 = 64]
1015///
1016/// 3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
1017/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
1018/// No adjustment needed - returns consumer layout directly.
1019///
1020xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
1021 xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
1022 DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
1023
1024 int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1025 int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1026
1027 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1028 ArrayRef<int64_t> resShape = resVecTy.getShape();
1029 SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
1030 SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
1031 SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
1032 SmallVector<int64_t> laneLayout =
1033 consumerLayout.getEffectiveLaneLayoutAsInt();
1034
1035 assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
1036 "laneData must be available for all dimensions");
1037 size_t innerMostDim = srcShape.size() - 1;
1038 int64_t sgDataValue = -1;
1039 int64_t instDataValue = -1;
1040 int64_t laneDataValue = -1;
1041 if (srcElemTyBitWidth > resElemTyBitWidth) {
1042 // When casting to a smaller bitwidth, multiply the result layout
1043 // accordingly to ensure it can be divided by the ratio back to the
1044 // source layout.
1045 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
1046 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1047 sgDataValue = sgData[innerMostDim];
1048 while ((sgDataValue <= resShape[innerMostDim]) &&
1049 (sgDataValue % bitWidthRatio) != 0)
1050 sgDataValue *= 2;
1051 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1052 instDataValue = instData[innerMostDim];
1053 const int innermostDimLaneLayout = laneLayout.empty()
1054 ? uArch->getSubgroupSize()
1055 : laneLayout[innerMostDim];
1056 // Adjust instDataValue so it still fits within an instruction after
1057 // dividing by bitWidthRatio
1058 while ((instDataValue <= resShape[innerMostDim]) &&
1059 (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
1060 instDataValue *= 2;
1061 assert((resShape[innerMostDim] % instDataValue) == 0 &&
1062 "resShape, instData, and lanelayout for innermost must be 2^n !");
1063 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1064 laneDataValue = laneData[innerMostDim];
1065 while ((laneDataValue <= resShape[innerMostDim]) &&
1066 (laneDataValue % bitWidthRatio != 0))
1067 laneDataValue *= 2;
1068 }
1069 // Now set only instData and laneData, preserving sgData
1070 xegpu::DistributeLayoutAttr resLayout;
1071 resLayout = consumerLayout.setDimData(innerMostDim, sgDataValue,
1072 instDataValue, laneDataValue);
1073 return resLayout;
1074 }
1075 return consumerLayout;
1076}
1077
1078/// Sets up the result layout for an interleave operation to ensure the source
1079/// layout can be safely derived. Interleave doubles the innermost dimension,
1080/// so the result layout must ensure that laneData is a multiple
1081/// of 2, and instData must be divisible by innermostDimLaneLayout * 2.
1082///
1083/// Example:
1084/// Interleave: vector<128x256xf4> -> vector<128x512xf4>
1085/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4], instData=[1, 64]
1086/// Result layout adjustment to ensure source can be safely inferred:
1087/// - laneData must be >= 2 and multiple of 2 (so source = laneData/2 is
1088/// valid)
1089/// - instData must be divisible by (16 * 2 = 32) (so source = instData/2 is
1090/// valid)
1091/// - Adjusted instData: ensure (instData % 32 == 0)
1092///
1093xegpu::DistributeLayoutAttr xegpu::setupInterleaveResultLayout(
1094 xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
1095 DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
1096
1097 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1098 SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
1099 SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
1100 SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
1101 SmallVector<int64_t> laneLayout =
1102 consumerLayout.getEffectiveLaneLayoutAsInt();
1103
1104 assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
1105 "consumer layout rank must match source shape rank");
1106 const size_t innerMostDim = srcShape.size() - 1;
1107 int64_t sgDataValue = -1;
1108 int64_t instDataValue = -1;
1109 int64_t laneDataValue = -1;
1110
1111 // Interleave doubles the innermost dimension (ratio = 2)
1112 constexpr int ratio = 2;
1113
1114 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1115 sgDataValue = sgData[innerMostDim];
1116 // Ensure sgDataValue is divisible by ratio so source sgData can be inferred
1117 while ((sgDataValue <= srcShape[innerMostDim]) &&
1118 (sgDataValue % ratio != 0))
1119 sgDataValue *= ratio;
1120 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1121 instDataValue = instData[innerMostDim];
1122 const int innermostDimLaneLayout = laneLayout.empty()
1123 ? uArch->getSubgroupSize()
1124 : laneLayout[innerMostDim];
1125 // Adjust instDataValue so it can be divided by (innermostDimLaneLayout *
1126 // ratio) when inferring the source layout
1127 while ((instDataValue <= srcShape[innerMostDim]) &&
1128 (instDataValue % (innermostDimLaneLayout * ratio) != 0))
1129 instDataValue *= ratio;
1130 assert((srcShape[innerMostDim] % instDataValue) == 0 &&
1131 "srcShape, instData, and laneLayout for innermost must be 2^n!");
1132 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1133 laneDataValue = laneData[innerMostDim];
1134 // Ensure laneDataValue is at least 2 and divisible by ratio
1135 // so that source laneData = laneDataValue/2 is valid
1136 while ((laneDataValue <= srcShape[innerMostDim]) &&
1137 (laneDataValue % ratio != 0))
1138 laneDataValue *= ratio;
1139 }
1140
1141 return consumerLayout.setDimData(innerMostDim, sgDataValue, instDataValue,
1142 laneDataValue);
1143}
1144
1145/// Sets up the result layout for an insert strided slice operation.
1146/// Creates a result layout based on the specified layout kind (InstData or
1147/// Lane).
1148xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
1149 xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
1150 VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
1151 const xegpu::uArch::uArch *uArch) {
1152
1153 xegpu::DistributeLayoutAttr requiredResLayout;
1154 SmallVector<int64_t> consumerInstData =
1155 consumerLayout.getEffectiveInstDataAsInt();
1156 SmallVector<int64_t> consumerLaneData =
1157 consumerLayout.getEffectiveLaneDataAsInt();
1158 SmallVector<int64_t> consumerLaneLayout =
1159 consumerLayout.getEffectiveLaneLayoutAsInt();
1160 ArrayRef<int64_t> srcShape = srcVectorTy.getShape();
1161 int64_t instDataValue = -1;
1162 int64_t laneDataValue = -1;
1163
1164 requiredResLayout = consumerLayout;
1165 int srcRank = srcShape.size();
1166
1167 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1168 assert(true &&
1169 "subgroup layout assignment not supported for insertStridedSlice.");
1170 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1171 for (int dim = 0; dim < srcRank; dim++) {
1172 instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
1173 requiredResLayout =
1174 requiredResLayout.setDimData(dim, -1, instDataValue, -1);
1175 }
1176 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1177 for (int dim = 0; dim < srcRank; dim++) {
1178 assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
1179 "srcShape must be divisible by laneLayout for all dimensions");
1180 laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
1181 consumerLaneData[dim]);
1182 requiredResLayout =
1183 requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
1184 }
1185 }
1186 return requiredResLayout;
1187}
1188
1189/// Sets up the anchor layout for load gather and load matrix operation.
1190/// load matrix lowers to load gather and 1d block load. All of them share the
1191/// same layout setup logic.
1192/// For Subgroup layout, uses the consumer layout directly.
1193/// non-chunked loads (1D or 2D):
1194/// InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
1195/// LaneLayout = {1, ..., subgroupSize}
1196/// lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
1197/// chunked loads (2D only):
1198/// InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
1199/// LaneLayout = {subgroupSize, 1}
1200/// lane_data={1,min(consumer, maxLaneLoadSize)}
1201static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
1202 xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
1203 xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
1204 int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
1205
1206 if (layoutKind == xegpu::LayoutKind::Subgroup)
1207 return consumerLayout;
1208
1209 SmallVector<int64_t> consumerInstData =
1210 consumerLayout.getEffectiveInstDataAsInt();
1211 SmallVector<int64_t> consumerLaneData =
1212 consumerLayout.getEffectiveLaneDataAsInt();
1213
1214 SmallVector<int> instData(resShape.size(), 1);
1215 SmallVector<int> laneLayout(resShape.size(), 1);
1216 SmallVector<int> laneData(resShape.size(), 1);
1217
1218 if (!isChunkedLoad) {
1219 if (layoutKind == xegpu::LayoutKind::InstData) {
1220 instData.back() = std::min(static_cast<int>(consumerInstData.back()),
1221 maxChunkSize * subgroupSize);
1222 return xegpu::LayoutAttr::get(context, instData);
1223 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1224 laneData.back() =
1225 std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
1226 laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
1227 resShape.back() / laneData.back());
1228 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1229 }
1230 } else {
1231 assert(resShape.size() == 2 && "Chunked Store must access 2D tensor tile.");
1232 if (layoutKind == xegpu::LayoutKind::InstData) {
1233 instData[0] = subgroupSize;
1234 instData[1] =
1235 std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
1236 return xegpu::LayoutAttr::get(context, instData);
1237 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1238 laneLayout[0] = subgroupSize;
1239 laneData[1] =
1240 std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
1241 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1242 }
1243 }
1244 return nullptr;
1245}
1246
1247/// Sets up the anchor layout for a load gather operation.
/// Queries the per-lane maximum load size from the uArch's LoadGather
/// instruction for the result element bitwidth and delegates to
/// setupGenericLoadAnchorLayout. A chunkSize > 1 selects the chunked rules.
1248xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
1249 xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
1250 xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
1251
1252 const int subgroupSize = uArch->getSubgroupSize();
1253 ArrayRef<int64_t> resShape = resVecTy.getShape();
1254 auto context = resVecTy.getContext();
1255 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1256
// NOTE(review): the dyn_cast result below is dereferenced without a null check.
1257 const auto *uArchInstruction =
1258 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1260 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1261
1262 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
1263 (chunkSize > 1), maxChunkSize, resShape,
1264 subgroupSize);
1265}
1266
1267/// Sets up the anchor layout for load matrix operation.
1268/// TODO: enhance load matrix to indicate lowering to chunked load or not.
/// Uses the LoadGather instruction's per-lane maximum load size for the result
/// element bitwidth. It always applies the non-chunked rules of
/// setupGenericLoadAnchorLayout (isChunkedLoad = false).
1269xegpu::DistributeLayoutAttr
1271 VectorType resVecTy,
1272 xegpu::DistributeLayoutAttr consumerLayout,
1273 const xegpu::uArch::uArch *uArch) {
1274
1275 const int subgroupSize = uArch->getSubgroupSize();
1276 ArrayRef<int64_t> resShape = resVecTy.getShape();
1277 auto context = resVecTy.getContext();
1278 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1279
// NOTE(review): the dyn_cast result below is dereferenced without a null check.
1280 const auto *uArchInstruction =
1281 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1283 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1284 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
1285 false, maxChunkSize, resShape,
1286 subgroupSize);
1287}
1288
1289/// Sets up the anchor layout for store scatter and store matrix operation.
1290/// store matrix lowers to store scatter and 1d block store. All of them share
1291/// the same layout setup logic. Subgroup layouts are not supported (returns nullptr).
1292/// non-chunked stores (1D or 2D):
1293/// InstData = {1, ..., min(subgroupSize, srcVec)}
1294/// LaneLayout = {1, ..., min(subgroupSize, srcVec)}
1295/// lane_data = {1, ..., 1}
1296/// chunked stores (2D only):
1297/// InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
1298/// LaneLayout = {subgroupSize, 1}
1299/// lane_data={1,min(srcVec, maxLaneStoreSize)}
/// Returns nullptr for any other layout kind.
1300static xegpu::DistributeLayoutAttr
1302 mlir::MLIRContext *context, bool isChunkedStore,
1303 int maxChunkSize, ArrayRef<int64_t> srcShape,
1304 int subgroupSize) {
1305
1306 int srcShapeSize = srcShape.size();
1307 SmallVector<int> instData(srcShapeSize, 1);
1308 SmallVector<int> laneLayout(srcShapeSize, 1);
1309 SmallVector<int> laneData(srcShapeSize, 1);
1310
// NOTE(review): `assert(true && ...)` can never fire. The Subgroup case only
// returns nullptr.
1311 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1312 assert(true &&
1313 "subgroup layout assignment not supported for storeScatter.");
1314 return nullptr;
1315 }
1316
1317 if (!isChunkedStore) {
1318 if (layoutKind == xegpu::LayoutKind::InstData) {
1319 instData[srcShapeSize - 1] =
1320 std::min(subgroupSize, static_cast<int>(srcShape.back()));
1321 return xegpu::LayoutAttr::get(context, instData);
1322 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1323 laneLayout[srcShapeSize - 1] =
1324 std::min(subgroupSize, static_cast<int>(srcShape.back()));
1325 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1326 }
1327 } else {
1328 assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
1329 if (layoutKind == xegpu::LayoutKind::InstData) {
1330 instData[0] = subgroupSize;
1331 instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
1332 return xegpu::LayoutAttr::get(context, instData);
1333 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1334 laneLayout[0] = subgroupSize;
1335 laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
1336 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1337 }
1338 }
1339 return nullptr;
1340}
1341
1342/// Sets up the anchor layout for a store scatter operation.
/// Queries the per-lane maximum store size from the uArch's StoreScatter
/// instruction for the source element bitwidth and delegates to
/// setupGenericStoreAnchorLayout. A chunkSize > 1 selects the chunked rules.
1343xegpu::DistributeLayoutAttr
1345 VectorType srcVecTy, int chunkSize,
1346 const uArch::uArch *uArch) {
1347
1348 const int subgroupSize = uArch->getSubgroupSize();
1349 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1350 auto context = srcVecTy.getContext();
1351 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1352
// NOTE(review): the dyn_cast result below is dereferenced without a null check.
1353 const auto *uArchInstruction =
1354 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1356 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1357 return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
1358 maxChunkSize, srcShape, subgroupSize);
1359}
1360
1361/// Sets up the anchor layout for a store matrix operation.
/// Uses the StoreScatter instruction's per-lane maximum store size for the
/// source element bitwidth. It always applies the non-chunked rules of
/// setupGenericStoreAnchorLayout (isChunkedStore = false).
1362xegpu::DistributeLayoutAttr
1364 VectorType srcVecTy,
1365 const xegpu::uArch::uArch *uArch) {
1366
1367 const int subgroupSize = uArch->getSubgroupSize();
1368 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1369 auto context = srcVecTy.getContext();
1370 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1371
// NOTE(review): the dyn_cast result below is dereferenced without a null check.
1372 const auto *uArchInstruction =
1373 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1375 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1376
1377 return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
1378 srcShape, subgroupSize);
1379}
1380
1381// This function returns the default lane layout for a given vector type.
1382// - `packingSize` means multiple consecutive elements can be accessed
1383// together as a single unit.
1384// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
1385// 1x2xf16 w/o vnni).
1386template <typename RankedTy>
1387static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(
1388 RankedTy ty, const xegpu::uArch::uArch *uArch,
1389 std::optional<unsigned> packingSize = std::nullopt, bool vnni = false) {
1390 // Expecting a 1D or 2D vector.
1391 assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
1392 "Expected 1D non-vnni or 2D vector.");
1393 // Expecting int or float element type.
1394 assert(ty.getElementType().isIntOrFloat() &&
1395 "Expected int or float element type.");
1396
1397 auto context = ty.getContext();
1398 auto rank = ty.getRank();
1399 SmallVector<int> laneLayout(rank, 1);
1400 SmallVector<int> laneData(rank, 1);
1401 if (packingSize.has_value()) {
1402 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1403 int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
1404 laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
1405 }
1406 laneLayout.back() = uArch->getSubgroupSize();
1407 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1408}
1409
1410// This function returns all layouts for the given sgCount, whose sgData:
1411// 1. Evenly divides the wgShape.
1412// 2. Is a multiple of instData.
1413// Example:
1414// wgShape = [128, 64], instData = [8, 16], sgCount = 32
1415// Returns layouts:
1416// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
// Each returned pair is (sgLayout[0], sgLayout[1]). Only wgShape[0..1] and
// instData[0..1] are read, so 2D inputs are assumed. Results are sorted with
// the most balanced layouts first.
1417using LayoutRepresentation = std::pair<int64_t, int64_t>;
1420 int64_t sgCount) {
// Enumerate every factorization sgCount = sgLayout0 * sgLayout1.
1422 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
1423 if (sgCount % sgLayout0)
1424 continue;
1425 int64_t sgLayout1 = sgCount / sgLayout0;
1426 int64_t sgData0 = wgShape[0] / sgLayout0;
1427 int64_t sgData1 = wgShape[1] / sgLayout1;
1428 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
1429 (sgData0 % instData[0] || sgData1 % instData[1]))
1430 continue;
1431 candidates.emplace_back(sgLayout0, sgLayout1);
1432 }
1433 // Sort primarily by how balanced they are
1434 // (i.e., minimize the absolute difference between the two dimensions), and
1435 // secondarily by the first dimension in ascending order.
// NOTE(review): the int64_t differences are narrowed to int below.
1436 llvm::sort(candidates, [](const LayoutRepresentation &lhs,
1437 const LayoutRepresentation &rhs) {
1438 int diffLhs = std::abs(lhs.first - lhs.second);
1439 int diffRhs = std::abs(rhs.first - rhs.second);
1440 if (diffLhs != diffRhs)
1441 return diffLhs < diffRhs;
1442 return lhs.first < rhs.first;
1443 });
1444 return candidates;
1445}
1446
1447/// Helper function to compute inst_data vectors for DPAS operands A, B, and
1448/// C/D.
/// M (inst_data[rank-2] of A and C/D) is derived from A's leading dim.
/// N is derived from B's trailing dim: inst_data[rank-1] of B uses B's
/// element-type supported N, and that of C/D uses C/D's element-type supported
/// N. Both are chosen via xegpu::getLargestDivisor.
/// K is the subgroup size, or, for DPAS_MX (isDpasMx), the first entry of
/// getSupportedK.
/// Returns std::nullopt when any getLargestDivisor call yields -1, or when no
/// K is available.
/// NOTE(review): M is read from shape.front(), while inst_data is written at
/// rank-2. These agree only for 2-D A; confirm for higher ranks.
1449static std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
1451getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
1453 bool isDpasMx = false) {
1454 const int subgroupSize = uArch->getSubgroupSize();
1455
// NOTE(review): neither dyn_cast result below is null-checked before use.
1456 const xegpu::uArch::MMAInstructionInterface *uArchInstruction;
1457 if (isDpasMx)
1458 uArchInstruction = dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1461 else
1462 uArchInstruction =
1463 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1465
1466 const unsigned dataALen = aTy.getShape().front();
1467 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
1468 const int maxALen =
1469 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
1470
1471 const unsigned dataBLen = bTy.getShape().back();
1472 auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
1473 const int maxBLen =
1474 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
1475
1476 auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
1477 const int maxCLen =
1478 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
1479 if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
1480 return std::nullopt;
1481
1482 // For DPAS_MX, use getSupportedK to get the scaled K dimension.
1483 // assume single element in the returned vector.
1484 int kDimSize = subgroupSize;
1485 if (isDpasMx) {
1486 auto supportedKLen = uArchInstruction->getSupportedK(aTy.getElementType());
1487 if (supportedKLen.empty())
1488 return std::nullopt;
1489 kDimSize = supportedKLen[0];
1490 }
1491
1492 SmallVector<int64_t> instDataA(aTy.getRank(), 1);
1493 instDataA[aTy.getRank() - 2] = maxALen;
1494 instDataA[aTy.getRank() - 1] = kDimSize;
1495 SmallVector<int64_t> instDataB(bTy.getRank(), 1);
1496 instDataB[bTy.getRank() - 2] = kDimSize;
1497 instDataB[bTy.getRank() - 1] = maxBLen;
1498 SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
1499 instDataCD[cdTy.getRank() - 2] = maxALen;
1500 instDataCD[cdTy.getRank() - 1] = maxCLen;
1501 return std::make_tuple(instDataA, instDataB, instDataCD);
1502}
1503
1504/// Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
1505/// Returns the three layouts if successful, nullopt otherwise.
/// A candidate sg layout must appear in the getValidLayouts() results for A, B
/// and C/D, and must satisfy A.shape[1] / sgLayout[1] == B.shape[0] /
/// sgLayout[0].
/// Among those candidates, the consumer's sg layout wins if present.
/// Otherwise the first match in B's list (which is sorted most balanced first)
/// is used. All three operands share that sg_layout.
/// For sg_data: A keeps its full shape[1], B keeps its full shape[0], and only
/// A/C-D shape[0] and B/C-D shape[1] are divided by the sg layout.
1506static std::optional<
1507 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1508 xegpu::DistributeLayoutAttr>>
1510 VectorType bTy, VectorType cdTy,
1511 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1512 const xegpu::uArch::uArch *uArch) {
1513 auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
1514 if (!instDataVecs)
1515 return std::nullopt;
1516 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1517 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
1518 instDataCD.size() == 2 &&
1519 "Sg layout creation expects valid 2D inst data");
1520
// Remember the consumer's sg layout if it is a workgroup layout. Indexing
// [0]/[1] assumes that it has at least two entries.
1521 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
1522 if (consumerLayout && consumerLayout.isForWorkgroup()) {
1523 SmallVector<int64_t> sgLayoutD = consumerLayout.getEffectiveSgLayoutAsInt();
1524 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
1525 }
1526
1527 // Get all valid layouts for A, B and C/D operands
1528 auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
1529 auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
1530 auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
1531 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1532 return std::nullopt;
1533
1534 // Pick the best subgroup layout
1535 llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
1536 llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
1537 layoutsCD.end());
1538 std::optional<LayoutRepresentation> bestPick;
1539 auto checkAlignedSgDataAB = [&](LayoutRepresentation sgLayout) {
1540 return aTy.getShape().back() / sgLayout.second ==
1541 bTy.getShape().front() / sgLayout.first;
1542 };
1543 for (auto &sgLayout : layoutsB) {
1544 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1545 if (!checkAlignedSgDataAB(sgLayout))
1546 continue;
1547 // Is in (A and B and CD) and matches consumer -> best pick
1548 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1549 bestPick = sgLayout;
1550 break;
1551 }
1552 // Is in (A and B and CD) layoutsB is ordered from most
1553 // balanced to least. So the first one we see is the most balanced one,
1554 // remember it and later only update if there is one that matches the
1555 // consumer.
1556 if (!bestPick)
1557 bestPick = sgLayout;
1558 }
1559 }
1560 if (!bestPick)
1561 return std::nullopt;
1562
1563 SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
1564 static_cast<int>(bestPick->second)};
1565 SmallVector<int> sgDataA = {static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1566 static_cast<int>(aTy.getShape()[1])};
1567 SmallVector<int> sgDataB = {
1568 static_cast<int>(bTy.getShape()[0]),
1569 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
1570 SmallVector<int> sgDataCD = {
1571 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1572 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1573
1574 auto dpasALayout =
1575 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1576 DenseI32ArrayAttr::get(context, sgDataA), nullptr,
1577 nullptr, nullptr, nullptr);
1578 auto dpasBLayout =
1579 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1580 DenseI32ArrayAttr::get(context, sgDataB), nullptr,
1581 nullptr, nullptr, nullptr);
1582 auto dpasCDLayout =
1583 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1584 DenseI32ArrayAttr::get(context, sgDataCD), nullptr,
1585 nullptr, nullptr, nullptr);
1586
1587 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
1588}
1589
1590/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
1591/// The numSg and consumerLayout (optional) are only used by sg layout
1592/// creation.
///
/// Dispatches on layoutKind:
///  - Subgroup: asserts numSg > 0 and delegates to getupDpasSubgroupLayouts.
///  - InstData: builds inst_data-only LayoutAttrs from the vectors returned by
///    getDpasInstDataVectors; returns std::nullopt if those cannot be computed.
///  - Lane: builds lane layouts with getDefaultLaneLayout2DBlockIo, passing the
///    packed-format bit sizes reported by the uArch matrix-multiply instruction
///    for A and B (B is additionally requested with vnni = true) and no packing
///    size for C/D.
/// Any other layout kind yields std::nullopt.
1593std::optional<
1594 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1595 xegpu::DistributeLayoutAttr>>
1596xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
1597 VectorType bTy, VectorType cdTy,
1598 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1599 const xegpu::uArch::uArch *uArch) {
1600 auto context = aTy.getContext();
// NOTE(review): the instruction is looked up for every layout kind but only
// used on the Lane path, where the dyn_cast result is dereferenced without a
// null check.
1601 const auto *uArchInstruction =
1602 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1604
1605 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1606 assert(numSg > 0 &&
1607 "Number of subgroups must be provided for sg layout creation.");
1608 return getupDpasSubgroupLayouts(context, aTy, bTy, cdTy, consumerLayout,
1609 numSg, uArch);
1610 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1611 auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
1612 if (!instDataVecs)
1613 return std::nullopt;
1614 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1615 return std::make_tuple(
1616 xegpu::LayoutAttr::get(
1617 context, SmallVector<int>(instDataA.begin(), instDataA.end())),
1618 xegpu::LayoutAttr::get(
1619 context, SmallVector<int>(instDataB.begin(), instDataB.end())),
1620 xegpu::LayoutAttr::get(
1621 context, SmallVector<int>(instDataCD.begin(), instDataCD.end())));
1622 } else if (layoutKind == xegpu::LayoutKind::Lane) {
// A and B use the instruction's packed-format bit sizes; the trailing `true`
// on B is the vnni flag of getDefaultLaneLayout2DBlockIo.
1623 auto aLayout = getDefaultLaneLayout2DBlockIo(
1624 aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
1625 auto bLayout = getDefaultLaneLayout2DBlockIo(
1626 bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
1627 auto cdLayout = getDefaultLaneLayout2DBlockIo(
1628 cdTy, uArch /*, packingSize = std::nullopt */);
1629 return std::make_tuple(aLayout, bLayout, cdLayout);
1630 }
1631 return std::nullopt;
1632}
1633
1634/// Helper to create a scale layout derived from a matrix operand layout.
1635/// The scale layout is computed by mapping each dimension of the matrix layout
1636/// to the corresponding scale tensor dimension using the ratio between the
1637/// matrix and scale shapes.
///
/// Returns nullptr when scaleTy or matrixLayout is null, or when scaleTy has
/// an empty shape. Otherwise each field group present on matrixLayout
/// (sg_layout+sg_data, inst_data, lane_layout+lane_data) gets a scale
/// counterpart, absent groups stay null, and matrixLayout's order is reused
/// unchanged. isBScale selects the B-operand variant, which changes which
/// inst_data dimension is rescaled and how lane_layout is oriented.
1638static xegpu::DistributeLayoutAttr
1639createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy,
1640 VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout,
1641 bool isBScale, const xegpu::uArch::uArch *uArch) {
1642 if (!scaleTy || !matrixLayout)
1643 return nullptr;
1644
1645 // Calculate scaling factor by dividing matrix shape by scale shape
1646 ArrayRef<int64_t> matrixShape = matrixTy.getShape();
1647 ArrayRef<int64_t> scaleShape = scaleTy.getShape();
1648
1649 // A scale type with an empty shape gets no layout.
1650 if (scaleShape.empty())
1651 return nullptr;
// NOTE(review): below, scaleShape is indexed at rank - 2 and rank - 1 with
// rank == 2, so a 1-D scale type would be read past its shape. Confirm callers
// only pass 2-D scale types.
1652
// NOTE(review): uArchInstruction is the result of a dyn_cast and is
// dereferenced on the lane path below without a null check.
1653 auto uArchInstruction =
1654 dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1657
1658 int64_t rank = matrixLayout.getRank();
1659 assert(rank == 2 && "dpas layouts must be two dimensions");
1660
1661 SmallVector<int64_t> sgLayout = matrixLayout.getEffectiveSgLayoutAsInt();
1662 SmallVector<int64_t> sgData = matrixLayout.getEffectiveSgDataAsInt();
1663 SmallVector<int64_t> instData = matrixLayout.getEffectiveInstDataAsInt();
1664 SmallVector<int64_t> laneLayout = matrixLayout.getEffectiveLaneLayoutAsInt();
1665 SmallVector<int64_t> laneData = matrixLayout.getEffectiveLaneDataAsInt();
1666 auto order = matrixLayout.getOrder();
1667
// Subgroup level: keep the sg grid. Per dimension, sg_data is the scale extent
// divided by the number of sg tiles of the matrix (matrixShape / sgData),
// floored at 1.
1668 SmallVector<int> scaleSgLayout;
1669 SmallVector<int> scaleSgData;
1670 if (!sgLayout.empty() && !sgData.empty()) {
1671 scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
1672 scaleSgData.assign(sgData.begin(), sgData.end());
1673 scaleSgData[rank - 2] = std::max<int64_t>(
1674 scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
1675 scaleSgData[rank - 1] = std::max<int64_t>(
1676 scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
1677 }
1678
1679 // For DPAS_MX scales: if matrix has inst_data, scale needs adjusted
1680 // inst_data. Scale inst_data is derived from matrix inst_data divided by
1681 // scale factor.
// Only one dimension is rescaled: dim rank - 2 for B scales, dim rank - 1 for
// A scales. The other dimension keeps the matrix inst_data value.
1682 SmallVector<int> scaleInstData;
1683 if (!instData.empty()) {
1684 scaleInstData.assign(instData.begin(), instData.end());
1685 if (isBScale)
1686 scaleInstData[rank - 2] = std::max<int64_t>(
1687 scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
1688 1);
1689 else
1690 scaleInstData[rank - 1] = std::max<int64_t>(
1691 scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
1692 1);
1693 }
1694
// Lane level: start from the matrix lane layout. When isBScale differs from the
// instruction's row-major lane-order flag, swap the two lane_layout dims and
// clamp dim rank - 2 to the scale extent. lane_data then spreads the scale
// shape over lane_layout, floored at 1.
1695 SmallVector<int> scaleLaneLayout;
1696 SmallVector<int> scaleLaneData;
1697 if (!laneLayout.empty() && !laneData.empty()) {
1698 scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
1699 scaleLaneData.assign(laneData.begin(), laneData.end());
1700 bool isRowMajor = uArchInstruction->isLaneLayoutRowMajorOrder();
1701 if (isBScale ^ isRowMajor) {
1702 std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
1703 scaleLaneLayout[rank - 2] =
1704 std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
1705 }
1706 scaleLaneData[rank - 2] =
1707 std::max<int64_t>(scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
1708 scaleLaneData[rank - 1] =
1709 std::max<int64_t>(scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
1710 }
// Field groups that were not computed are passed as null attributes.
1711 return xegpu::LayoutAttr::get(
1712 context,
1713 scaleSgLayout.empty() ? nullptr
1714 : DenseI32ArrayAttr::get(context, scaleSgLayout),
1715 scaleSgData.empty() ? nullptr
1716 : DenseI32ArrayAttr::get(context, scaleSgData),
1717 scaleInstData.empty() ? nullptr
1718 : DenseI32ArrayAttr::get(context, scaleInstData),
1719 scaleLaneLayout.empty()
1720 ? nullptr
1721 : DenseI32ArrayAttr::get(context, scaleLaneLayout),
1722 scaleLaneData.empty() ? nullptr
1723 : DenseI32ArrayAttr::get(context, scaleLaneData),
1724 order);
1725}
1726
1727/// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
1728/// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
1729/// creation.
///
/// The A, B and C/D layouts are built as for plain dpas. The Subgroup path
/// reuses getupDpasSubgroupLayouts; the InstData path calls
/// getDpasInstDataVectors with isDpasMx = true; the Lane path uses
/// getDefaultLaneLayout2DBlockIo. The A_scale/B_scale layouts are then derived
/// from the A/B layouts with createScaleLayout (isBScale = false for A, true
/// for B). A scale layout is null when, for example, the corresponding scale
/// type is null. Returns std::nullopt when the A/B/C/D layouts cannot be built,
/// or for any other layout kind.
1730std::optional<
1731 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1732 xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1733 xegpu::DistributeLayoutAttr>>
1734xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
1735 VectorType bTy, VectorType cdTy, VectorType aScaleTy,
1736 VectorType bScaleTy,
1737 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1738 const xegpu::uArch::uArch *uArch) {
1739 auto context = aTy.getContext();
1740
1741 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1742 assert(numSg > 0 &&
1743 "Number of subgroups must be provided for sg layout creation.");
1744 auto dpasLayouts = getupDpasSubgroupLayouts(context, aTy, bTy, cdTy,
1745 consumerLayout, numSg, uArch);
1746 if (!dpasLayouts)
1747 return std::nullopt;
1748
1749 auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
1750
1751 // Create scale layouts
1752 auto aScaleLayout =
1753 createScaleLayout(context, aTy, aScaleTy, dpasALayout, false, uArch);
1754
1755 auto bScaleLayout =
1756 createScaleLayout(context, bTy, bScaleTy, dpasBLayout, true, uArch);
1757
1758 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1759 bScaleLayout);
1760 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1761 auto instDataVecs =
1762 getDpasInstDataVectors(aTy, bTy, cdTy, uArch, /*isDpasMx=*/true);
1763 if (!instDataVecs)
1764 return std::nullopt;
1765 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1766
1767 auto dpasALayout = xegpu::LayoutAttr::get(
1768 context, SmallVector<int>(instDataA.begin(), instDataA.end()));
1769 auto dpasBLayout = xegpu::LayoutAttr::get(
1770 context, SmallVector<int>(instDataB.begin(), instDataB.end()));
1771 auto dpasCDLayout = xegpu::LayoutAttr::get(
1772 context, SmallVector<int>(instDataCD.begin(), instDataCD.end()));
1773
1774 // Create scale layouts
1775 auto aScaleLayout =
1776 createScaleLayout(context, aTy, aScaleTy, dpasALayout, false, uArch);
1777 auto bScaleLayout =
1778 createScaleLayout(context, bTy, bScaleTy, dpasBLayout, true, uArch);
1779
1780 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1781 bScaleLayout);
1782 } else if (layoutKind == xegpu::LayoutKind::Lane) {
// NOTE(review): packing sizes come from the non-scaled
// SubgroupMatrixMultiplyAcc instruction, and the dyn_cast result is
// dereferenced without a null check. Confirm both are intended.
1783 const auto *uArchInstruction =
1784 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1786 auto aLayout = getDefaultLaneLayout2DBlockIo(
1787 aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
1788 auto bLayout = getDefaultLaneLayout2DBlockIo(
1789 bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
1790 auto cdLayout = getDefaultLaneLayout2DBlockIo(cdTy, uArch);
1791
1792 // Create scale layouts
1793 auto aScaleLayout =
1794 createScaleLayout(context, aTy, aScaleTy, aLayout, false, uArch);
1795 auto bScaleLayout =
1796 createScaleLayout(context, bTy, bScaleTy, bLayout, true, uArch);
1797
1798 return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
1799 bScaleLayout);
1800 }
1801 return std::nullopt;
1802}
1803
1804xegpu::DistributeLayoutAttr xegpu::inferSourceLayoutFromResultForNonAnchorOp(
1805 OpOperand &operand, xegpu::DistributeLayoutAttr resLayout) {
1806 if (!resLayout)
1807 return nullptr;
1808 Operation *op = operand.getOwner();
1809 unsigned idx = operand.getOperandNumber();
1810
1811 // For vector::BroadcastOp, infer the source layout from the result layout.
1812 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
1813 auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
1814 if (!srcTy)
1815 return nullptr;
1817 resLayout, broadcast.getResultVectorType().getShape(),
1818 srcTy.getShape());
1819 }
1820
1821 // For vector::MultiDimReductionOp, infer source layout from result layout
1822 // using reduction dims. Acc operand is expected to have the same layout as
1823 // the result.
1824 if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
1825 if (idx == 0) {
1826 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
1827 return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
1828 }
1829 if (idx == 1)
1830 return resLayout;
1831 }
1832
1833 if (auto reduction = dyn_cast<vector::ReductionOp>(op))
1834 return xegpu::inferReductionSourceLayout(resLayout);
1835
1836 // For vector::BitCastOp, infer source layout from result layout using
1837 // element type bitwidths.
1838 if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1839 int resElemBitWidth =
1840 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1841 int srcElemBitWidth =
1842 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1843 return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
1844 srcElemBitWidth);
1845 }
1846
1847 // For vector::ShapeCastOp, infer source layout from result layout using
1848 // shapes.
1849 if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
1851 resLayout, shapeCast.getResultVectorType().getShape(),
1852 shapeCast.getSourceVectorType().getShape());
1853 }
1854
1855 // For vector::InsertStridedSliceOp, infer source layout from result layout.
1856 // Dest vector must have the same layout as the result.
1857 if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1858 if (idx == 0) {
1860 resLayout, insertSlice.getDestVectorType().getShape(),
1861 insertSlice.getSourceVectorType().getShape());
1862 }
1863 if (idx == 1)
1864 return resLayout;
1865 }
1866
1867 // For vector::Insert Op, infer source layout from result layout using
1868 // shapes.
1869 if (auto insert = dyn_cast<vector::InsertOp>(op)) {
1870 VectorType resVecTy = dyn_cast<VectorType>(insert.getResult().getType());
1871 VectorType valueToStoreTy =
1872 dyn_cast<VectorType>(insert.getValueToStore().getType());
1873
1874 if ((idx == 0) && valueToStoreTy) {
1875 return xegpu::inferInsertSourceLayout(resLayout, resVecTy.getShape(),
1876 valueToStoreTy.getShape());
1877 }
1878 if (idx == 1)
1879 return resLayout;
1880 }
1881
1882 // For vector::Extract Op, infer source layout from result layout using
1883 // shapes.
1884 if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
1885 VectorType srcVecTy = dyn_cast<VectorType>(extract.getSource().getType());
1886 VectorType resVecTy = dyn_cast<VectorType>(extract.getResult().getType());
1887 if (!srcVecTy || !resVecTy)
1888 return nullptr;
1889 return xegpu::inferExtractSourceLayout(resLayout, resVecTy.getShape(),
1890 srcVecTy.getShape());
1891 }
1892
1893 // For vector::TransposeOp, infer source layout from result layout using
1894 // permutation.
1895 if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
1896 return xegpu::inferTransposeSourceLayout(resLayout,
1897 transpose.getPermutation());
1898 }
1899
1900 // For vector::BitCastOp, infer source layout from result layout using
1901 // element type bitwidths.
1902 if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1903 int resElemBitWidth =
1904 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1905 int srcElemBitWidth =
1906 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1907 return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
1908 srcElemBitWidth);
1909 }
1910
1911 // for vector::interleave
1912 if (auto interleave = dyn_cast<vector::InterleaveOp>(op)) {
1913 return xegpu::inferInterleaveSourceLayout(resLayout);
1914 }
1915
1916 // for vector::deinterleave
1917 if (auto deinterleave = dyn_cast<vector::DeinterleaveOp>(op)) {
1918 return xegpu::inferDeinterleaveSourceLayout(resLayout);
1919 }
1920
1921 // For vector::ExtractStridedSliceOp, simply return result layout
1922 if (dyn_cast<vector::ExtractStridedSliceOp>(op))
1923 return resLayout;
1924
1925 // For elementwise operations, all operands must have the same layout as the
1926 // result.
1928 return resLayout;
1929
1930 return nullptr;
1931}
1932
1933xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
1934 Operation *op = operand.getOwner();
1935 // Anchor ops declare the layout they
1936 // require on each operand. Trust that declaration directly so that
1937 // ResolveLayoutConflicts compares producer-vs-declared
1938 if (isa<xegpu::AnchorLayoutInterface>(op))
1939 return xegpu::getDistributeLayoutAttr(operand);
1940 // For non-anchor ops, derive the operand layout from the op's result
1941 // layout via op-specific semantics.
1942 xegpu::DistributeLayoutAttr resLayout;
1943 if (op->getNumResults() == 1 || isa<vector::DeinterleaveOp>(op))
1944 resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1945 return inferSourceLayoutFromResultForNonAnchorOp(operand, resLayout);
1946}
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition PDL.cpp:62
lhs
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
std::pair< int64_t, int64_t > LayoutRepresentation
static xegpu::DistributeLayoutAttr createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy, VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout, bool isBScale, const xegpu::uArch::uArch *uArch)
Helper to create a scale layout derived from a matrix operand layout.
static std::optional< std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > > getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy, const xegpu::uArch::uArch *uArch, bool isDpasMx=false)
Helper function to compute inst_data vectors for DPAS operands A, B, and C/D.
static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result)
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, bool isChunkedStore, int maxChunkSize, ArrayRef< int64_t > srcShape, int subgroupSize)
Sets up the anchor layout for store scatter and store matrix operation.
static std::optional< std::tuple< xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr > > getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy, VectorType bTy, VectorType cdTy, xegpu::DistributeLayoutAttr consumerLayout, int numSg, const xegpu::uArch::uArch *uArch)
Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
static void propagateResultsToRegularOperands(Operation *op)
static void propagateRegionResultsToYieldOperands(mlir::RegionBranchTerminatorOpInterface yieldOp)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp)
static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout)
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch, std::optional< unsigned > packingSize=std::nullopt, bool vnni=false)
static void walkRegionBackward(Region &region, llvm::function_ref< void(Operation *)> visit)
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad, int maxChunkSize, ArrayRef< int64_t > resShape, int subgroupSize)
Sets up the anchor layout for load gather and load matrix operation.
Block represents an ordered list of Operations.
Definition Block.h:33
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
bool hasAttrOfType(NameT &&name)
Definition Operation.h:600
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:408
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:511
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
Definition Operation.h:497
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:625
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool empty()
Definition Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert operation.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr inferSourceLayoutFromResultForNonAnchorOp(OpOperand &operand, DistributeLayoutAttr resLayout)
Infers the source layout attribute for an operand using result layout attribute.
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an extract operation.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
bool isTriviallyRematerializable(Operation *op)
Returns true if op is safe and cheap to clone: it has no side effects, no regions,...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:156
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:168