MLIR 23.0.0git
XeGPULayoutImpl.cpp
Go to the documentation of this file.
1//===---- XeGPULayoutImpl.cpp - MLIR Utilities for XeGPUOps
2//------------------===//
3//
4// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements layout utility functions for XeGPU dialect
11// transformation.
12//
13//===----------------------------------------------------------------------===//
14
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/ValueRange.h"
30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/Support/FormatVariadic.h"
32#include <cstdint>
33#include <numeric>
34
35using namespace mlir;
36
40 out.reserve(attrs.size());
41
42 for (auto attr : attrs) {
43 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
44 auto newLayout = dist.dropSgLayoutAndData();
45 if (newLayout)
46 out.emplace_back(attr.getName(), newLayout);
47 } else {
48 out.push_back(attr);
49 }
50 }
51
52 return out;
53}
54
58 out.reserve(attrs.size());
59
60 for (auto attr : attrs) {
61 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
62 auto newLayout = dist.dropInstData();
63 if (newLayout)
64 out.emplace_back(attr.getName(), newLayout);
65 } else {
66 out.push_back(attr);
67 }
68 }
69
70 return out;
71}
72
73// Sets the layout on a TensorDesc value by updating its type to include
74// the given layout, if the type does not already have a layout attached.
75static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout) {
76 auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(val.getType());
77 if (!tensorDescTy || tensorDescTy.getLayoutAttr())
78 return;
79 auto typeWithLayout = xegpu::TensorDescType::get(
80 tensorDescTy.getContext(), tensorDescTy.getShape(),
81 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
82 val.setType(typeWithLayout);
83}
84
85// the walkRegionBackward() is a recursive function
86// the input rootOp is the function operation, which is also a region op.
87// it recursively processes the region op in reverse topological order.
88static void walkRegionBackward(Region &region,
90
91 // Use post-order traversal to process blocks in reverse topological order.
92 // This ensures that use blocks are visited before def blocks, which is
93 // required for backward layout propagation.
94 if (region.empty())
95 return;
96 llvm::ReversePostOrderTraversal<Region *> rpot(&region);
97 SmallVector<Block *> blocks(rpot.begin(), rpot.end());
98 for (Block *block : llvm::reverse(blocks)) {
99 // ops: back -> front
100 for (Operation &op : llvm::reverse(*block)) {
101 // make sure we first visit inside the region op (so yield op first)
102 // and then move to region op itself
103 // Regions are iterated in forward order so that for multi-region ops
104 // like scf.while, earlier regions (e.g., "before/cond") are processed
105 // first. This ensures that when a later region's terminator (e.g., "do"
106 // yield) needs the layout of an earlier region's block args, those
107 // layouts are already available from use points.
108 for (Region &nested : op.getRegions())
109 walkRegionBackward(nested, visit);
110
111 visit(&op);
112 }
113 }
114}
115
116static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
117 xegpu::DistributeLayoutAttr layout = nullptr;
118 for (OpOperand &use : result.getUses()) {
119 if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
120 if (!layout)
121 layout = tmpLayout;
122 break;
123 }
124 }
125 return layout;
126}
127
128// Returns true if `op` is safe and cheap to clone (no side effects, no
129// regions, and all operands are themselves trivially rematerializable, e.g.
130// block-arg-free pure value generators such as `vector.step`, splat
131// `arith.constant`, or `vector.create_mask` whose operands are constants).
133 if (!op || op->getNumRegions() != 0)
134 return false;
135 if (!isMemoryEffectFree(op))
136 return false;
137 for (Value v : op->getOperands()) {
138 Operation *defOp = v.getDefiningOp();
139 if (!defOp)
140 return false;
141 if (!isTriviallyRematerializable(defOp))
142 return false;
143 }
144 return true;
145}
146
147// For regular operations: First the result layouts are propagated from uses.
148// Then the result layouts are propagated to uses (operands).
150 if (op->getNumResults() == 0)
151 return;
152 if (op->getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
153 return;
154 OpResult result = op->getResult(0);
155 xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
156 Type resultType = result.getType();
157
158 if (!resLayout)
159 return;
160
161 // Recover layout for TensorDesc type results by updating the type to include
162 // the layout. Vector type results are handled separately below.
163 if (isa<xegpu::TensorDescType>(resultType))
164 setTensorDescLayout(result, resLayout);
165
166 // Recover layout for vector type results, or for multi-reduction ops which
167 // may reduce to a scalar that still needs a layout.
168 if (isa<VectorType>(resultType) || isa<vector::MultiDimReductionOp>(op))
170
171 if (isa<vector::DeinterleaveOp>(op))
172 xegpu::setTemporaryLayout(op->getResult(1), resLayout);
173
174 for (OpOperand &opr : op->getOpOperands()) {
175 xegpu::DistributeLayoutAttr operandLayout =
177 if (isa<VectorType>(opr.get().getType()) && operandLayout)
178 xegpu::setTemporaryLayout(opr, operandLayout);
179 }
180}
181
182// Propagate layout from region op results and sibling region block args
183// to yield/condition operands. For each successor of this terminator:
184// - Parent successor: propagate from parent op's result layouts (use points).
185// - Region successor: propagate from target region's block arg layouts (use
186// points), e.g., scf.yield in "after/do" region propagates to "before/cond"
187// block args.
189 mlir::RegionBranchTerminatorOpInterface yieldOp) {
190 auto regionBranchOp =
191 dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
192 if (!regionBranchOp)
193 return;
194
196 SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
197 yieldOp.getSuccessorRegions(operandAttrs, successors);
198
199 for (const RegionSuccessor &successor : successors) {
200 OperandRange succOps = yieldOp.getSuccessorOperands(successor);
201 if (succOps.empty())
202 continue;
203 unsigned beginIdx = succOps.getBeginOperandIndex();
204 ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
205 unsigned count = std::min<unsigned>(succOps.size(), successorInputs.size());
206
207 for (unsigned i = 0; i < count; ++i) {
208 xegpu::DistributeLayoutAttr layout;
209 if (successor.isOperation()) {
210 // For parent successor, get layout from external use points of the
211 // parent op's results.
212 auto regionResult = regionBranchOp->getResult(i);
213 layout = getLayoutFromUsePoints(regionResult);
214 if (layout) {
215 // set layout for the region op, like scf.loop
216 xegpu::setTemporaryLayout(regionResult, layout);
217 if (isa<xegpu::TensorDescType>(regionResult.getType()))
218 setTensorDescLayout(regionResult, layout);
219 }
220 } else {
221 // For region successor, get layout from the target region's block
222 // arg use points (e.g., "before/cond" region args for scf.while
223 // "after/do" yield).
224 layout = getLayoutFromUsePoints(successorInputs[i]);
225 }
226 if (!layout)
227 continue;
228 auto operandType = succOps[i].getType();
229 if (isa<VectorType>(operandType) ||
230 dyn_cast<xegpu::TensorDescType>(operandType))
231 // recover layout for yield op operands
232 xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i), layout);
233 }
234 }
235}
236
237// Propagate layout from region arguments to region op's init operands. This
238// sets the temporary layout for region arguments and init operands.
239LogicalResult
240xegpu::propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp,
241 xegpu::GetLayoutFnTy getLayoutOfValue) {
242 // Iterate all regions of the region op. For each block argument that has a
243 // layout (obtained via `getLayoutOfValue`), trace back to find the
244 // corresponding init operand of the regionOp and set the layout on it.
245 // This works generically for scf.for, scf.while, and other
246 // RegionBranchOpInterface ops.
247 for (Region &region : regionOp->getRegions()) {
248 RegionSuccessor regionSuccessor(&region);
249 // Use getSuccessorInputs to get the block arguments that correspond to
250 // predecessor operands. This correctly handles ops like scf.for where
251 // the induction variable is a block arg but not a successor input.
252 ValueRange successorInputs = regionOp.getSuccessorInputs(regionSuccessor);
253 for (auto [inputIdx, regionArg] : llvm::enumerate(successorInputs)) {
254 auto layout = getLayoutOfValue(regionArg);
255 if (!layout)
256 continue;
257
258 // Recover layout for tensor_desc block args by updating the type.
259 if (isa<xegpu::TensorDescType>(regionArg.getType()))
260 setTensorDescLayout(regionArg, layout);
261
262 // Recover layout for region op operands, like scf.for's init operands.
263 // Find all predecessor values that flow into this block argument.
264 SmallVector<Value> predValues;
265 regionOp.getPredecessorValues(regionSuccessor, inputIdx, predValues);
266 for (Value predVal : predValues) {
267 // Match predecessor value to an operand of the regionOp.
268 for (OpOperand &operand : regionOp->getOpOperands()) {
269 if (operand.get() == predVal)
270 xegpu::setTemporaryLayout(operand, layout);
271 }
272 }
273 }
274 }
275 return success();
276}
277
278// Prerequisite for Layout Recovery
279// It relies on the following invariant:
280// 1. there is no layout conflict between different uses of the same definition.
281// 2. each definition has a well-defined layout requirement at its use point.
282// - Every definition must have at least one use that appears after it in
283// topological order.
284// - TODO: If a definition has no such use (e.g., a loop result or region
285// output), an explicit convert_layout operation is inserted to create a
286// use.
287// - Only the result of convert_layout is permitted to have no subsequent
288// use.
289//
290// The recovery proceeds by scanning the operation in reverse topological order
291// as follows:
292// For regular operations: First the result layouts are propagated from uses.
293// Then the result layouts are propagated to operands.
294//
295// For region operations (e.g., loops):
296// - When backward propagation reaches a region op, it sets the layout of
297// the region op’s results according to use points like regular ops.
298// - Then, the result layouts (such as a loop output) are propagated to
299// their corresponding operands in the yield.
300// - When backward propagation reaches the first operation inside the
301// region, the pass examines the region op’s initialization list,
302// propagating from region arguments to the corresponding initialization
303// operands.
304// - This ensures that layouts are consistently propagated
305// across region boundaries while preserving a single well-defined use for
306// each definition at the region-op level.
308 auto processFunc = [&](Region &body, StringRef funcName) {
309 walkRegionBackward(body, [&](Operation *op) {
310 if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
313 } else if (auto yieldOp =
314 dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
316 } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
318 }
319 });
320 };
322 rootOp->walk([&](func::FuncOp func) {
323 processFunc(func.getBody(), func.getSymName());
324 });
325 rootOp->walk([&](gpu::GPUFuncOp func) {
326 processFunc(func.getBody(), func.getName());
327 });
328
329 return true;
330}
331
332template <typename T, typename>
333void xegpu::removeLayoutAttr(const T &operandOrResult) {
334 Operation *owner = operandOrResult.getOwner();
335 std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
336 if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
337 owner->removeAttr(name);
338}
339
340// Explicit instantiation for OpResult
341template void
343
344// Explicit instantiation for OpOperand
345template void
347
349 op->walk([&](Operation *nestOp) {
350 // Remove all attributes of DistributeLayoutAttr type
351 SmallVector<StringAttr> attrsToRemove;
352 for (auto namedAttr : nestOp->getAttrs()) {
353 if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
354 attrsToRemove.push_back(namedAttr.getName());
355 }
356 for (auto attrName : attrsToRemove)
357 nestOp->removeAttr(attrName);
358 });
359}
360
362 op->walk([&](Operation *nestOp) {
363 SmallVector<StringAttr> attrsToRemove;
364 for (auto namedAttr : nestOp->getDiscardableAttrs()) {
365 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
366 attrsToRemove.push_back(namedAttr.getName());
367 }
368 for (auto attrName : attrsToRemove)
369 nestOp->removeDiscardableAttr(attrName);
370 });
371}
372
373/// Infers the source layout attribute for a broadcast operation given the
374/// result layout attribute, result shape, source shape.
375xegpu::DistributeLayoutAttr
376xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
377 ArrayRef<int64_t> resShape,
378 ArrayRef<int64_t> srcShape) {
379
380 SmallVector<int64_t> bcastDims;
381 size_t dimDiff = resShape.size() - srcShape.size();
382 auto bcastSourceLayout = resLayout;
383 for (size_t i = dimDiff; i < resShape.size(); i++) {
384 if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
385 bcastDims.push_back(i);
386 }
387
388 // the sg_layout and lane_layout for unit dimensions are preserved so it can
389 // be propagate to producer op so potentially used by the multi-reduction op.
390 if (!bcastDims.empty())
391 bcastSourceLayout = bcastSourceLayout.setUnitDimData(bcastDims);
392
393 if (dimDiff > 0) {
394 SmallVector<int64_t> sliceDims;
395 for (size_t i = 0; i < dimDiff; i++)
396 sliceDims.push_back(i);
397 bcastSourceLayout = xegpu::SliceAttr::get(
398 resLayout.getContext(), bcastSourceLayout,
399 DenseI64ArrayAttr::get(resLayout.getContext(), sliceDims));
400 }
401 return bcastSourceLayout;
402}
403
404/// Infers the source layout attribute for a reduction operation given the
405/// result layout attribute and reduced dims.
406xegpu::DistributeLayoutAttr
407xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
408 SmallVector<int64_t> reduceDims) {
409
410 assert(isa<xegpu::SliceAttr>(resLayout) &&
411 "reduction result layout must be slice layout");
412
413 xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
414
415 assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
416 "reduction dims must match with slice dims");
417
418 return sliceLayout.getParent();
419}
420
421xegpu::DistributeLayoutAttr
422xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
423 return xegpu::inferMultiReductionSourceLayout(resLayout, {0});
424}
425
426/// Infers the source layout attribute for a transpose operation given the
427/// result layout attribute and permutation.
428///
429/// vector.transpose semantics is `result[i] = source[permutation[i]]`, so
430/// `result_layout[i] = source_layout[permutation[i]]`. To recover the source
431/// layout from the result layout we must apply the inverse permutation.
432xegpu::DistributeLayoutAttr
433xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
434 ArrayRef<int64_t> permutation) {
436 invertPermutationVector(permutation);
437 return resLayout.transposeDims(inversePermutation);
438}
439
440/// Infers the source layout attribute for a bitcast operation given the
441/// result layout attribute, result element type bitwidth, and source element
442/// type bitwidth.
443xegpu::DistributeLayoutAttr
444xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
445 int resElemTyBitWidth, int srcElemTyBitWidth) {
446
447 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
448 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
449 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
450 size_t sgDataSize = sgData.size();
451 size_t instDataSize = instData.size();
452 size_t laneDataSize = laneData.size();
453 int64_t sgDataValue = -1;
454 int64_t instDataValue = -1;
455 int64_t laneDataValue = -1;
456 int64_t dim = resLayout.getRank() - 1;
457
458 if (srcElemTyBitWidth <= resElemTyBitWidth) {
459 int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
460 if (sgDataSize)
461 sgDataValue = sgData.back() * bitWidthRatio;
462 if (instDataSize)
463 instDataValue = instData.back() * bitWidthRatio;
464 if (laneDataSize)
465 laneDataValue = laneData.back() * bitWidthRatio;
466 } else {
467 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
468 if (sgDataSize) {
469 assert((sgData.back() % bitWidthRatio) == 0 &&
470 "sgData not divisible by bitWidthRatio");
471 sgDataValue = sgData.back() / bitWidthRatio;
472 }
473 if (instDataSize) {
474 assert((instData.back() % bitWidthRatio) == 0 &&
475 "instData not divisible by bitWidthRatio");
476 instDataValue = instData.back() / bitWidthRatio;
477 }
478 if (laneDataSize) {
479 assert((laneData.back() % bitWidthRatio) == 0 &&
480 "laneData not divisible by bitWidthRatio");
481 laneDataValue = laneData.back() / bitWidthRatio;
482 }
483 }
484
485 xegpu::DistributeLayoutAttr finalSrcLayout;
486 finalSrcLayout =
487 resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
488
489 return finalSrcLayout;
490}
491
492/// Infers the source layout attribute for an interleave operation given the
493/// result layout attribute. Interleave doubles the size of the innermost
494/// dimension, so the layout inference is similar to bitcast where the source
495/// element type is larger than the result element type (ratio = 2).
496xegpu::DistributeLayoutAttr
497xegpu::inferInterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
498
499 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
500 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
501 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
502 size_t sgDataSize = sgData.size();
503 size_t instDataSize = instData.size();
504 size_t laneDataSize = laneData.size();
505 int64_t sgDataValue = -1;
506 int64_t instDataValue = -1;
507 int64_t laneDataValue = -1;
508 int64_t dim = resLayout.getRank() - 1;
509
510 // Interleave doubles the innermost dimension, so we need to halve the
511 // layout values (similar to bitcast with ratio = 2)
512 constexpr int ratio = 2;
513 if (sgDataSize) {
514 assert((sgData.back() % ratio) == 0 &&
515 "sgData not divisible by interleave ratio");
516 sgDataValue = sgData.back() / ratio;
517 }
518 if (instDataSize) {
519 assert((instData.back() % ratio) == 0 &&
520 "instData not divisible by interleave ratio");
521 instDataValue = instData.back() / ratio;
522 }
523 if (laneDataSize) {
524 assert((laneData.back() % ratio) == 0 &&
525 "laneData not divisible by interleave ratio");
526 laneDataValue = laneData.back() / ratio;
527 }
528
529 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
530}
531
532/// Infers the source layout attribute for a deinterleave operation given the
533/// result layout attribute. Deinterleave halves the size of the innermost
534/// dimension, so the layout inference is similar to bitcast where the source
535/// element type is smaller than the result element type (ratio = 2).
536xegpu::DistributeLayoutAttr
537xegpu::inferDeinterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
538
539 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
540 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
541 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
542 size_t sgDataSize = sgData.size();
543 size_t instDataSize = instData.size();
544 size_t laneDataSize = laneData.size();
545 int64_t sgDataValue = -1;
546 int64_t instDataValue = -1;
547 int64_t laneDataValue = -1;
548 int64_t dim = resLayout.getRank() - 1;
549
550 // Deinterleave halves the innermost dimension, so we need to double the
551 // layout values (similar to bitcast with ratio = 2)
552 constexpr int ratio = 2;
553 if (sgDataSize)
554 sgDataValue = sgData.back() * ratio;
555 if (instDataSize)
556 instDataValue = instData.back() * ratio;
557 if (laneDataSize)
558 laneDataValue = laneData.back() * ratio;
559
560 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
561}
562
563/// Infers the source layout attribute for an insert strided slice operation
564/// given the result layout attribute, result shape, and source shape. Removes
565/// leading dimensions from the result layout to match the source shape size.
566xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
567 xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
568 ArrayRef<int64_t> srcShape) {
569
570 int srcShapeSize = srcShape.size();
571 int resShapeSize = resShape.size();
572 int dimDiff = resShapeSize - srcShapeSize;
573
574 if (dimDiff > 0) {
575 // assert that the leading dimensions being sliced off are not distributed
576 // (i.e. sg_layout and lane_layout for those dimensions are all 1)
577 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
578 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
579 for (int i = 0; i < dimDiff; i++) {
580 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
581 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
582 "Leading dimensions being sliced off must not be distributed");
583 }
584 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
585 }
586 return resLayout;
587}
588
589/// Infers the source layout attribute for an insert operation
590/// given the result layout attribute, result shape, and source shape. Removes
591/// leading dimensions from the result layout to match the source shape size.
592// TODO: add propagation support for insert op
593xegpu::DistributeLayoutAttr
594xegpu::inferInsertSourceLayout(xegpu::DistributeLayoutAttr resLayout,
595 ArrayRef<int64_t> resShape,
596 ArrayRef<int64_t> srcShape) {
597
598 int srcShapeSize = srcShape.size();
599 int resShapeSize = resShape.size();
600 int dimDiff = resShapeSize - srcShapeSize;
601
602 if (dimDiff > 0) {
603 // assert that the leading dimensions being sliced off are not distributed
604 // (i.e. sg_layout and lane_layout for those dimensions are all 1)
605 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
606 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
607 for (int i = 0; i < dimDiff; i++) {
608 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
609 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
610 "Leading dimensions being sliced off must not be distributed");
611 }
612 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
613 }
614 return resLayout;
615}
616
617/// Infers the source layout attribute for extract operation
618/// given the result layout attribute, result shape, and source shape. Adds
619/// leading dimensions to the source layout to match the source shape size.
620// TODO: add layout attribute interface: expandDim() and use it here.
621// TODO: add propagation support for extract op
622xegpu::DistributeLayoutAttr
623xegpu::inferExtractSourceLayout(xegpu::DistributeLayoutAttr resLayout,
624 ArrayRef<int64_t> resShape,
625 ArrayRef<int64_t> srcShape) {
626
627 int srcShapeSize = srcShape.size();
628 int resShapeSize = resShape.size();
629 int dimDiff = srcShapeSize - resShapeSize;
630 auto context = resLayout.getContext();
631 // construct the source layout by adding unit dimensions to the front of
632 // result layout
633 if (dimDiff > 0) {
634 auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
635 auto sgData = resLayout.getEffectiveSgDataAsInt();
636 auto instData = resLayout.getEffectiveInstDataAsInt();
637 auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
638 auto laneData = resLayout.getEffectiveLaneDataAsInt();
639 auto order = resLayout.getEffectiveOrderAsInt();
640
641 // Example: result shape is 3D with order [1, 2, 0], source shape is 5D
642 // (adding 2 leading dimensions). Expected source order: [3, 4, 2, 1, 0]
643 // Step 1: shift existing order by dimDiff: [1, 2, 0] -> [3, 4, 2]
644 // Step 2: append new leading dims in reverse (slowest first): [3, 4, 2, 1,
645 // 0]
646
647 // Shift existing dimension indices in order by dimDiff to account for the
648 // new leading dimensions being added to the source shape
649 for (auto &o : order)
650 o += dimDiff;
651
652 // Add unit dimensions to the front of non-empty layout vectors and append
653 // the new dimension indices to the order array in reverse (slowest
654 // dimension has the lowest index and appears last in the order array)
655 for (int i = 0; i < dimDiff; i++) {
656 if (!sgLayout.empty())
657 sgLayout.insert(sgLayout.begin(), 1);
658 if (!sgData.empty())
659 sgData.insert(sgData.begin(), 1);
660 if (!instData.empty())
661 instData.insert(instData.begin(), 1);
662 if (!laneLayout.empty())
663 laneLayout.insert(laneLayout.begin(), 1);
664 if (!laneData.empty())
665 laneData.insert(laneData.begin(), 1);
666 order.push_back(dimDiff - 1 - i);
667 }
668
669 DenseI32ArrayAttr orderAttr = resLayout ? resLayout.getOrder() : nullptr;
670 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
671 if (v.empty())
672 return DenseI32ArrayAttr();
673 SmallVector<int32_t> v32(v.begin(), v.end());
674 return DenseI32ArrayAttr::get(context, v32);
675 };
676 auto srcLayout = xegpu::LayoutAttr::get(
677 context, sgLayout.empty() ? nullptr : toAttr(sgLayout),
678 sgData.empty() ? nullptr : toAttr(sgData),
679 instData.empty() ? nullptr : toAttr(instData),
680 laneLayout.empty() ? nullptr : toAttr(laneLayout),
681 laneData.empty() ? nullptr : toAttr(laneData),
682 (!orderAttr || orderAttr.empty()) ? nullptr : toAttr(order));
683 return srcLayout;
684 }
685 return resLayout;
686}
687
688/// Infers the source layout attribute for a shape cast operation given the
689/// result layout attribute, result shape, and source shape.
690xegpu::DistributeLayoutAttr
691xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
692 ArrayRef<int64_t> resShape,
693 ArrayRef<int64_t> srcShape) {
694
695 // There are three use cases:
696 // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
697 // tensor before broadcast
698 // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
699 // for multi-stage reduction
700 // 3. combines all dims to a single dim and put in the innermost dim in 2d as
701 // [1, combinedData] or [combinedData]. Say, [2, 4, 8] -> [1, 64] or [64]
702 // Use cases are only supported after workgroup distribution,
703 // like cross-sg reduction saves multidimension data to
704 // 1D slm buffer, shapecast inserted by cse/canonicalization passes.
705
706 // Use case 1: Shapes only differ by expanding unit dimensions, for broadcast
707 SmallVector<int64_t> expandedUnitDims;
708
709 if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
710 // create a slice layout for the source by removing the expanded unit dims
711 auto sliceDimsAttr = DenseI64ArrayAttr::get(
712 resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
713 auto srcLayout =
714 xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
715 return srcLayout;
716 }
717
718 // Use case 2: Dim split from source to result, for multi-stage reduction
719 SmallVector<SmallVector<int64_t>> splitDimGroups;
720 if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups)) {
721 auto srcLayout = resLayout;
722 for (const auto &dimGroup : splitDimGroups)
723 srcLayout = srcLayout.collapseDims(dimGroup);
724
725 return srcLayout;
726 }
727
728 // Use case 3: General dim collapse, for cross-sg reduction to SLM and other
729 // shape casts where consecutive src dims fold into a single dst dim.
730 //
731 // Mirrors use case 2's elegant shape: walk the dst-side groups and call
732 // a single layout-attribute primitive per group. Here the primitive is
733 // `expandDim(dim, targetShape)`, the inverse of `collapseDims`. It applies
734 // the per-field distribution policy required for a no-data-movement collapse
735 // (sg_layout/lane_layout spread outer-to-inner; sg_data/lane_data/inst_data
736 // fill innermost-first; inst_data is seeded from lane_layout * lane_data).
737 // See LayoutAttr::expandDim for the full policy.
738 //
739 // Iteration goes innermost-first (reverse dst order) so that each
740 // expandDim/dropDims call only mutates dst positions whose indices are
741 // unaffected by earlier calls.
743 if (xegpu::matchDimCollapse(srcShape, resShape, collapseDims)) {
744 auto srcLayout = resLayout;
745 for (int64_t dstIdx = static_cast<int64_t>(collapseDims.size()) - 1;
746 dstIdx >= 0; --dstIdx) {
747 ArrayRef<int64_t> srcDims = collapseDims[dstIdx];
748 if (srcDims.empty()) {
749 // Unit dst dim with no backing src dim: drop it.
750 srcLayout = srcLayout.dropDims({dstIdx});
751 continue;
752 }
753 if (srcDims.size() == 1)
754 // 1:1 mapping, nothing to do for this dim.
755 continue;
756 SmallVector<int64_t> targetShape;
757 targetShape.reserve(srcDims.size());
758 for (int64_t d : srcDims)
759 targetShape.push_back(srcShape[d]);
760 srcLayout = srcLayout.expandDim(dstIdx, targetShape);
761 }
762 return srcLayout;
763 }
764 llvm_unreachable("running into unsupported shape cast scenarios");
765 return nullptr;
766}
767
768/// Infers the layout attribute for mask and offset operand for Chunked load
769/// and store, given the anchor layout attribute for the value being load/store.
770xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
771 xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
772 auto rank = payloadLayout.getRank();
773 if (chunkSize > 1)
774 return payloadLayout.dropDims(
775 llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
776 return payloadLayout;
777}
778
779/// Sets up layout for reduction operations by creating a SliceAttr for the
780/// result.
781///
782/// Algorithm Overview:
783/// This function attempts to construct a source layout that, when sliced along
784/// reduction dimensions, produces a result layout compatible with the
785/// consumer layout.
786///
787/// For subgroup layouts, it first tries to align the source layout's subgroup
788/// layout and data with the consumer's layout on non-reduction dimensions.
789/// Then, it distributes remaining subgroups across reduction dimensions. This
790/// avoids subgroup data redistribution overhead between the reduced result and
791/// its consumer. When the consumer layout is a slice layout, it attempts to
792/// reuse the slice layout's parent layout for the source to further minimize
793/// potential data redistribution.
794///
795/// InstData requires {1, ..., min(maxReduceVectorSize, srcShape), subgroupSize}
796/// Lane Layout requires {1, ..., 1, subgroupSize}
797/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
798///
799/// Examples:
800/// 1. Subgroup layout - Row reduction on 2D tensor:
801/// srcShape=[32, 128], reductionDims=[1], resShape=[32], subgroupSize=16,
802/// NumSg=32
803/// * Consumer Layout:
804/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>, dims =
805/// [1]>}
806/// * Result Layout:
807/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8],sg_data=[8, 16]>, dims =
808/// [1]>}
809/// Note that the sg_layout is reused but sg_data needs to be adjusted to
810/// evenly distribute the source tensor tile among the reduction dim.
811///
812/// 2. Subgroup layout - Same example above but consumer doesn't have a
813/// reusable slice layout.
814/// * Consumer Layout:
815/// #xegpu.layout<sgLayout=[32], sgData=[1]>
816/// * Result Layout:
817/// #xegpu.slice<#xegpu.layout<sgLayout=[32,1], sgData=[1, 64]>, dims =
818/// [1]>}
819/// * Consumer Layout:
820/// #xegpu.slice<#xegpu.layout<sgLayout=[8, 2, 4], sgData=[4, 64, 32]>,
821/// dims = [1, 2]>}
822/// * Result Layout:
823/// #xegpu.slice<#xegpu.layout<sgLayout=[8,4], sgData=[4, 32]>, dims =
824/// [1]>}
825/// Note that the consumer's layout can't be directly reused as is.
826/// So the algorithm distributes all subgroups on non reduction dimensions
827/// first and then distribute remaining subgroups on the reduction
828/// dimension.
829///
830/// 3. InstData layout - Column reduction:
831/// srcShape=[32, 64], reductionDims=[0], subgroupSize=16
832/// Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
833/// innermost)
834///
835/// 4. Lane layout - Multi-dimensional reduction:
836/// srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
837/// Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
838/// (subgroupSize on innermost dim, max vector size on reduction dim)
839
841 xegpu::LayoutKind layoutKind, VectorType srcVecTy,
842 DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
843 int numSg, const xegpu::uArch::uArch *uArch) {
844
845 auto srcShape = srcVecTy.getShape();
846 int srcRank = srcShape.size();
847 auto context = srcVecTy.getContext();
848
849 // Helper lambda to convert int64 vectors to int32 DenseArrayAttr
850 auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
851 SmallVector<int32_t> vec32(vec.begin(), vec.end());
852 return DenseI32ArrayAttr::get(context, vec32);
853 };
854
855 const int subgroupSize = uArch->getSubgroupSize();
856 int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
 // Parent layout of the returned slice. It stays null when `layoutKind`
 // matches none of the three branches below.
857 xegpu::DistributeLayoutAttr srcLayout;
858 if (layoutKind == xegpu::LayoutKind::Subgroup) {
859 xegpu::SliceAttr consumerSliceLayout =
860 dyn_cast_if_present<xegpu::SliceAttr>(consumerLayout);
 // Reuse path: the consumer is a slice over exactly `reductionDims`, so its
 // parent layout is taken as the starting point for the source layout.
861 if (consumerSliceLayout &&
862 consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
863 srcLayout = consumerSliceLayout.getParent();
864 SmallVector<int64_t> sgLayoutFromConsumer =
865 srcLayout.getEffectiveSgLayoutAsInt();
 // Only the reduction dims receive new sg_data (from the shape ratio of
 // srcShape and the consumer's sg_layout). If the ratio cannot be computed
 // the parent layout is used unmodified.
866 auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
867 if (srcSgData)
868 for (int dim = 0; dim < srcRank; dim++) {
869 if (llvm::is_contained(reductionDims, dim))
870 srcLayout =
871 srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
872 }
873 } else {
 // Build path: construct a fresh layout from the consumer's per-dimension
 // sg_layout / sg_data / order (when a consumer layout is present).
 // NOTE(review): the fallback arms of the next three conditionals are not
 // visible in this listing.
874 SmallVector<int64_t> consumerSgLayout =
875 consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
877 SmallVector<int64_t> consumerSgData =
878 consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
880 SmallVector<int64_t> consumerOrder =
881 consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
883 DenseI32ArrayAttr orderAttr =
884 consumerLayout ? consumerLayout.getOrder() : nullptr;
885 SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
 // Subgroup budget still to be distributed: taken from the consumer layout
 // when there is one, otherwise from `numSg`.
886 int remainingSgCount =
887 consumerLayout ? consumerLayout.getNumSubgroups() : numSg;
888 int consumerIdx = 0;
889
890 // First pass: Match consumer's layout on non-reduction dimensions
 // The i-th non-reduction source dim is paired with the i-th consumer dim.
891 for (int i = 0; i < srcRank; i++) {
892 if (!llvm::is_contained(reductionDims, i) &&
893 consumerIdx < static_cast<int>(consumerSgLayout.size())) {
894 sgLayout[i] = consumerSgLayout[consumerIdx];
895 sgData[i] = consumerSgData[consumerIdx];
896 remainingSgCount /= sgLayout[i];
897 order[i] = consumerOrder[consumerIdx];
898 consumerIdx++;
899 }
900 }
901
902 // Second pass: Distribute remaining subgroups across reduction dimensions
903 // the reduction to scalar case is handled only by this loop
 // Reduction dims get order values that follow all consumer dims.
904 int64_t remainOrder = consumerSgLayout.size();
905 for (int i = 0; i < srcRank; i++) {
906 if (llvm::is_contained(reductionDims, i)) {
 // Give this dim as many of the remaining subgroups as its extent allows.
907 sgLayout[i] =
908 std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
909 assert((srcShape[i] % sgLayout[i] == 0) &&
910 "source shape not divisible by sg_layout");
911 sgData[i] = srcShape[i] / sgLayout[i];
912 remainingSgCount /= sgLayout[i];
913 order[i] = remainOrder++;
914 }
915 }
916
917 assert(remainingSgCount == 1 && "not all subgroups distributed");
 // The order attribute is only emitted when the consumer carried a non-empty
 // order; otherwise it is left null.
918 srcLayout = xegpu::LayoutAttr::get(
919 context, toInt32Attr(sgLayout), toInt32Attr(sgData),
920 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
921 /*lane_data =*/nullptr, /*order =*/
922 (!orderAttr || orderAttr.empty()) ? nullptr : toInt32Attr(order));
923 }
924 } else if (layoutKind == xegpu::LayoutKind::InstData) {
925
 // inst_data is 1 everywhere except the two innermost dims:
 // min(maxReduceVectorSize, extent) on the second-innermost dim (rank >= 2
 // only) and min(subgroupSize, extent) on the innermost dim.
926 SmallVector<int64_t> instData(srcRank, 1);
927 if (srcRank >= 2)
928 instData[srcRank - 2] =
929 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
930 instData[srcRank - 1] =
931 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
932 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
933 } else if (layoutKind == xegpu::LayoutKind::Lane) {
934
 // Lanes are laid out along the innermost dim only; lane_data is 1 except
 // for min(maxReduceVectorSize, extent) on the second-innermost dim.
935 SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
936 laneLayout[srcRank - 1] =
937 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
938 if (srcRank >= 2)
939 laneData[srcRank - 2] =
940 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
941 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
942 toInt32Attr(laneData));
943 }
944
 // The result layout is the source layout sliced along the reduction dims.
945 return xegpu::SliceAttr::get(context, srcLayout,
946 DenseI64ArrayAttr::get(context, reductionDims));
947}
948
949/// Sets up layout for Reduction operations by creating a SliceAttr for the
950/// result.
/// Only LayoutKind::Lane produces a parent layout here: a 1-D layout with
/// lane_layout = min(subgroupSize, srcShape[0]) and lane_data = 1. For the
/// other layout kinds the parent layout passed to SliceAttr::get stays null.
951xegpu::SliceAttr
953 VectorType srcVecTy,
954 const xegpu::uArch::uArch *uArch) {
955
956 auto srcShape = srcVecTy.getShape();
957 auto context = srcVecTy.getContext();
958 auto subgroupSize = uArch->getSubgroupSize();
959 xegpu::LayoutAttr srcLayout;
960
961 if (layoutKind == xegpu::LayoutKind::Subgroup) {
 // NOTE(review): `assert(true && ...)` can never fire, so this branch only
 // documents the limitation; confirm whether `assert(false && ...)` or
 // llvm_unreachable was intended.
962 assert(true && "subgroup layout assignment not supported for reduction (op "
963 "is not expected at this level).");
964 } else if (layoutKind == xegpu::LayoutKind::InstData) {
 // NOTE(review): same always-true assert as in the Subgroup branch.
965 assert(true && "instData layout assignment not supported for reduction (op "
966 "is not expected at this level).");
967 } else if (layoutKind == xegpu::LayoutKind::Lane) {
 // Only the first entry of srcShape is read, i.e. the source is treated as
 // a 1-D vector here.
968 SmallVector<int32_t> laneLayout(1), laneData(1);
969 laneLayout[0] = std::min(subgroupSize, static_cast<int32_t>(srcShape[0]));
970 laneData[0] = 1;
971 srcLayout = xegpu::LayoutAttr::get(
972 context, DenseI32ArrayAttr::get(context, laneLayout),
973 DenseI32ArrayAttr::get(context, laneData));
974 }
975
 // presumably slices away dimension 0 of the parent layout - TODO confirm
 // how the literal 0 is interpreted by DenseI64ArrayAttr::get.
976 auto result = xegpu::SliceAttr::get(context, srcLayout,
977 DenseI64ArrayAttr::get(context, 0));
978 return result;
979}
980
981/// Sets up the result layout for a bitcast operation.
982/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
983/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
984/// result layout can be correctly divided back to the source layout during
985/// inference.
986///
987/// Examples:
988/// 1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
989/// Consumer layout: instData=[1, 16], subgroupSize=16
990/// Source shape: [8, 32]
991/// Result layout: instData=[1, 32] (16 * 2)
992/// The innermost dimension is multiplied by 2 to maintain consistency.
993///
994/// 2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
995/// Consumer instData=[1, 16], subgroupSize=16
996/// Source shape: [4, 128]
997/// adjust the instData from [1, 16] to [1, 16 * 4 = 64]
998///
999/// 3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
1000/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
1001/// No adjustment needed - returns consumer layout directly.
1002///
1003xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
1004 xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
1005 DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
1006
1007 int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1008 int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1009
1010 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1011 ArrayRef<int64_t> resShape = resVecTy.getShape();
1012 SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
1013 SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
1014 SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
1015 SmallVector<int64_t> laneLayout =
1016 consumerLayout.getEffectiveLaneLayoutAsInt();
1017
1018 assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
1019 "laneData must be available for all dimensions");
1020 size_t innerMostDim = srcShape.size() - 1;
1021 int64_t sgDataValue = -1;
1022 int64_t instDataValue = -1;
1023 int64_t laneDataValue = -1;
1024 if (srcElemTyBitWidth > resElemTyBitWidth) {
1025 // When casting to a smaller bitwidth, multiply the result layout
1026 // accordingly to ensure it can be divided by the ratio back to the
1027 // source layout.
1028 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
1029 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1030 sgDataValue = sgData[innerMostDim];
1031 while ((sgDataValue <= resShape[innerMostDim]) &&
1032 (sgDataValue % bitWidthRatio) != 0)
1033 sgDataValue *= 2;
1034 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1035 instDataValue = instData[innerMostDim];
1036 const int innermostDimLaneLayout = laneLayout.empty()
1037 ? uArch->getSubgroupSize()
1038 : laneLayout[innerMostDim];
1039 // Adjust instDataValue so it still fits within an instruction after
1040 // dividing by bitWidthRatio
1041 while ((instDataValue <= resShape[innerMostDim]) &&
1042 (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
1043 instDataValue *= 2;
1044 assert((resShape[innerMostDim] % instDataValue) == 0 &&
1045 "resShape, instData, and lanelayout for innermost must be 2^n !");
1046 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1047 laneDataValue = laneData[innerMostDim];
1048 while ((laneDataValue <= resShape[innerMostDim]) &&
1049 (laneDataValue % bitWidthRatio != 0))
1050 laneDataValue *= 2;
1051 }
1052 // Now set only instData and laneData, preserving sgData
1053 xegpu::DistributeLayoutAttr resLayout;
1054 resLayout = consumerLayout.setDimData(innerMostDim, sgDataValue,
1055 instDataValue, laneDataValue);
1056 return resLayout;
1057 }
1058 return consumerLayout;
1059}
1060
1061/// Sets up the result layout for an interleave operation to ensure the source
1062/// layout can be safely derived. Interleave doubles the innermost dimension,
1063/// so the result layout must ensure that laneData is a multiple
1064/// of 2, and instData must be divisible by innermostDimLaneLayout * 2.
1065///
1066/// Example:
1067/// Interleave: vector<128x256xf4> -> vector<128x512xf4>
1068/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4], instData=[1, 64]
1069/// Result layout adjustment to ensure source can be safely inferred:
1070/// - laneData must be >= 2 and multiple of 2 (so source = laneData/2 is
1071/// valid)
1072/// - instData must be divisible by (16 * 2 = 32) (so source = instData/2 is
1073/// valid)
1074/// - Adjusted instData: ensure (instData % 32 == 0)
1075///
1076xegpu::DistributeLayoutAttr xegpu::setupInterleaveResultLayout(
1077 xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
1078 DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
1079
1080 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1081 SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
1082 SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
1083 SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
1084 SmallVector<int64_t> laneLayout =
1085 consumerLayout.getEffectiveLaneLayoutAsInt();
1086
1087 assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
1088 "consumer layout rank must match source shape rank");
1089 const size_t innerMostDim = srcShape.size() - 1;
1090 int64_t sgDataValue = -1;
1091 int64_t instDataValue = -1;
1092 int64_t laneDataValue = -1;
1093
1094 // Interleave doubles the innermost dimension (ratio = 2)
1095 constexpr int ratio = 2;
1096
1097 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1098 sgDataValue = sgData[innerMostDim];
1099 // Ensure sgDataValue is divisible by ratio so source sgData can be inferred
1100 while ((sgDataValue <= srcShape[innerMostDim]) &&
1101 (sgDataValue % ratio != 0))
1102 sgDataValue *= ratio;
1103 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1104 instDataValue = instData[innerMostDim];
1105 const int innermostDimLaneLayout = laneLayout.empty()
1106 ? uArch->getSubgroupSize()
1107 : laneLayout[innerMostDim];
1108 // Adjust instDataValue so it can be divided by (innermostDimLaneLayout *
1109 // ratio) when inferring the source layout
1110 while ((instDataValue <= srcShape[innerMostDim]) &&
1111 (instDataValue % (innermostDimLaneLayout * ratio) != 0))
1112 instDataValue *= ratio;
1113 assert((srcShape[innerMostDim] % instDataValue) == 0 &&
1114 "srcShape, instData, and laneLayout for innermost must be 2^n!");
1115 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1116 laneDataValue = laneData[innerMostDim];
1117 // Ensure laneDataValue is at least 2 and divisible by ratio
1118 // so that source laneData = laneDataValue/2 is valid
1119 while ((laneDataValue <= srcShape[innerMostDim]) &&
1120 (laneDataValue % ratio != 0))
1121 laneDataValue *= ratio;
1122 }
1123
1124 return consumerLayout.setDimData(innerMostDim, sgDataValue, instDataValue,
1125 laneDataValue);
1126}
1127
1128/// Sets up the result layout for an insert strided slice operation.
1129/// Creates a result layout based on the specified layout kind (InstData or
1130/// Lane).
1131xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
1132 xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
1133 VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
1134 const xegpu::uArch::uArch *uArch) {
1135
1136 xegpu::DistributeLayoutAttr requiredResLayout;
1137 SmallVector<int64_t> consumerInstData =
1138 consumerLayout.getEffectiveInstDataAsInt();
1139 SmallVector<int64_t> consumerLaneData =
1140 consumerLayout.getEffectiveLaneDataAsInt();
1141 SmallVector<int64_t> consumerLaneLayout =
1142 consumerLayout.getEffectiveLaneLayoutAsInt();
1143 ArrayRef<int64_t> srcShape = srcVectorTy.getShape();
1144 int64_t instDataValue = -1;
1145 int64_t laneDataValue = -1;
1146
1147 requiredResLayout = consumerLayout;
1148 int srcRank = srcShape.size();
1149
1150 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1151 assert(true &&
1152 "subgroup layout assignment not supported for insertStridedSlice.");
1153 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1154 for (int dim = 0; dim < srcRank; dim++) {
1155 instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
1156 requiredResLayout =
1157 requiredResLayout.setDimData(dim, -1, instDataValue, -1);
1158 }
1159 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1160 for (int dim = 0; dim < srcRank; dim++) {
1161 assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
1162 "srcShape must be divisible by laneLayout for all dimensions");
1163 laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
1164 consumerLaneData[dim]);
1165 requiredResLayout =
1166 requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
1167 }
1168 }
1169 return requiredResLayout;
1170}
1171
1172/// Sets up the anchor layout for load gather and load matrix operation.
1173/// load matrix lowers to load gather and 1d block load. All of them share the
1174/// same layout setup logic.
1175/// For Subgroup layout, uses the consumer layout directly.
1176/// non-chunked loads (1D or 2D):
1177/// InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
1178/// LaneLayout = {1, ..., subgroupSize}
1179/// lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
1180/// chunked loads (2D only):
1181/// InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
1182/// LaneLayout = {subgroupSize, 1}
1183/// lane_data={1,min(consumer, maxLaneLoadSize)}
1184static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
1185 xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
1186 xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
1187 int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
1188
1189 if (layoutKind == xegpu::LayoutKind::Subgroup)
1190 return consumerLayout;
1191
1192 SmallVector<int64_t> consumerInstData =
1193 consumerLayout.getEffectiveInstDataAsInt();
1194 SmallVector<int64_t> consumerLaneData =
1195 consumerLayout.getEffectiveLaneDataAsInt();
1196
1197 SmallVector<int> instData(resShape.size(), 1);
1198 SmallVector<int> laneLayout(resShape.size(), 1);
1199 SmallVector<int> laneData(resShape.size(), 1);
1200
1201 if (!isChunkedLoad) {
1202 if (layoutKind == xegpu::LayoutKind::InstData) {
1203 instData.back() = std::min(static_cast<int>(consumerInstData.back()),
1204 maxChunkSize * subgroupSize);
1205 return xegpu::LayoutAttr::get(context, instData);
1206 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1207 laneData.back() =
1208 std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
1209 laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
1210 resShape.back() / laneData.back());
1211 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1212 }
1213 } else {
1214 assert(resShape.size() == 2 && "Chunked Store must access 2D tensor tile.");
1215 if (layoutKind == xegpu::LayoutKind::InstData) {
1216 instData[0] = subgroupSize;
1217 instData[1] =
1218 std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
1219 return xegpu::LayoutAttr::get(context, instData);
1220 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1221 laneLayout[0] = subgroupSize;
1222 laneData[1] =
1223 std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
1224 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1225 }
1226 }
1227 return nullptr;
1228}
1229
1230/// Sets up the anchor layout for a load gather operation.
/// Collects the subgroup size, the result shape and the per-lane maximum load
/// size for the result element bit width, then defers to
/// setupGenericLoadAnchorLayout. The load is treated as chunked when
/// chunkSize > 1.
1231xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
1232 xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
1233 xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
1234
1235 const int subgroupSize = uArch->getSubgroupSize();
1236 ArrayRef<int64_t> resShape = resVecTy.getShape();
1237 auto context = resVecTy.getContext();
1238 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1239
 // NOTE(review): the dyn_cast result is dereferenced right below without a
 // null check.
1240 const auto *uArchInstruction =
1241 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1243 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1244
1245 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
1246 (chunkSize > 1), maxChunkSize, resShape,
1247 subgroupSize);
1248}
1249
1250/// Sets up the anchor layout for load matrix operation.
/// Uses the load gather instruction interface to obtain the per-lane maximum
/// load size and always requests the non-chunked layout from
/// setupGenericLoadAnchorLayout.
1251/// TODO: enhance load matrix to indicate lowering to chunked load or not.
1252xegpu::DistributeLayoutAttr
1254 VectorType resVecTy,
1255 xegpu::DistributeLayoutAttr consumerLayout,
1256 const xegpu::uArch::uArch *uArch) {
1257
1258 const int subgroupSize = uArch->getSubgroupSize();
1259 ArrayRef<int64_t> resShape = resVecTy.getShape();
1260 auto context = resVecTy.getContext();
1261 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1262
 // NOTE(review): the dyn_cast result is dereferenced right below without a
 // null check.
1263 const auto *uArchInstruction =
1264 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1266 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
 // `false`: load matrix is currently never treated as a chunked load.
1267 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
1268 false, maxChunkSize, resShape,
1269 subgroupSize);
1270}
1271
1272/// Sets up the anchor layout for store scatter and store matrix operation.
1273/// store matrix lowers to store scatter and 1d block store. All of them share
1274/// the same layout setup logic. For Subgroup layout, not supported yet.
1275/// non-chunked stores (1D or 2D):
1276/// InstData = {1, ..., subgroupSize}
1277/// LaneLayout = {1, ..., subgroupSize}
1278/// lane_data = {1, ..., 1}
1279/// chunked stores (2D only):
1280/// InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
1281/// LaneLayout = {subgroupSize, 1}
1282/// lane_data={1,min(srcVec, maxLaneStoreSize)}
/// In the non-chunked case the innermost entry is additionally clamped to the
/// innermost extent of srcShape. Returns nullptr for Subgroup and for any
/// layout kind that is not handled.
1283static xegpu::DistributeLayoutAttr
1285 mlir::MLIRContext *context, bool isChunkedStore,
1286 int maxChunkSize, ArrayRef<int64_t> srcShape,
1287 int subgroupSize) {
1288
1289 int srcShapeSize = srcShape.size();
 // Every entry defaults to 1; only the entries listed above are overwritten.
1290 SmallVector<int> instData(srcShapeSize, 1);
1291 SmallVector<int> laneLayout(srcShapeSize, 1);
1292 SmallVector<int> laneData(srcShapeSize, 1);
1293
1294 if (layoutKind == xegpu::LayoutKind::Subgroup) {
 // NOTE(review): `assert(true && ...)` can never fire; the caller just
 // receives a null layout.
1295 assert(true &&
1296 "subgroup layout assignment not supported for storeScatter.");
1297 return nullptr;
1298 }
1299
1300 if (!isChunkedStore) {
1301 if (layoutKind == xegpu::LayoutKind::InstData) {
1302 instData[srcShapeSize - 1] =
1303 std::min(subgroupSize, static_cast<int>(srcShape.back()));
1304 return xegpu::LayoutAttr::get(context, instData);
1305 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1306 laneLayout[srcShapeSize - 1] =
1307 std::min(subgroupSize, static_cast<int>(srcShape.back()));
1308 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1309 }
1310 } else {
1311 assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
1312 if (layoutKind == xegpu::LayoutKind::InstData) {
1313 instData[0] = subgroupSize;
1314 instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
1315 return xegpu::LayoutAttr::get(context, instData);
1316 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1317 laneLayout[0] = subgroupSize;
1318 laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
1319 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1320 }
1321 }
 // Reached only when layoutKind is neither InstData nor Lane.
1322 return nullptr;
1323}
1324
1325/// Sets up the anchor layout for a store scatter operation.
/// Collects the subgroup size, the source shape and the per-lane maximum store
/// size for the source element bit width, then defers to
/// setupGenericStoreAnchorLayout. The store is treated as chunked when
/// chunkSize > 1.
1326xegpu::DistributeLayoutAttr
1328 VectorType srcVecTy, int chunkSize,
1329 const uArch::uArch *uArch) {
1330
1331 const int subgroupSize = uArch->getSubgroupSize();
1332 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1333 auto context = srcVecTy.getContext();
1334 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1335
 // NOTE(review): the dyn_cast result is dereferenced right below without a
 // null check.
1336 const auto *uArchInstruction =
1337 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1339 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1340 return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
1341 maxChunkSize, srcShape, subgroupSize);
1342}
1343
1344/// Sets up the anchor layout for a store matrix operation.
/// Uses the store scatter instruction interface to obtain the per-lane maximum
/// store size and always requests the non-chunked layout from
/// setupGenericStoreAnchorLayout.
1345xegpu::DistributeLayoutAttr
1347 VectorType srcVecTy,
1348 const xegpu::uArch::uArch *uArch) {
1349
1350 const int subgroupSize = uArch->getSubgroupSize();
1351 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1352 auto context = srcVecTy.getContext();
1353 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1354
 // NOTE(review): the dyn_cast result is dereferenced right below without a
 // null check.
1355 const auto *uArchInstruction =
1356 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1358 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1359
 // `false`: store matrix is never treated as a chunked store here.
1360 return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
1361 srcShape, subgroupSize);
1362}
1363
1364// This function returns the default lane layout for a given vector type.
1365// - `packingSize` means multiple consecutive elements can be accessed
1366// together as a single unit.
1367// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
1368// 1x2xf16 w/o vnni).
1369template <typename RankedTy>
1370static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(
1371 RankedTy ty, const xegpu::uArch::uArch *uArch,
1372 std::optional<unsigned> packingSize = std::nullopt, bool vnni = false) {
1373 // Expecting at least 1D vector. For rank > 2, leading dims are batch dims.
1374 assert(((ty.getRank() >= 1 && !vnni) || ty.getRank() >= 2) &&
1375 "Expected at least 1D non-vnni or 2D vector.");
1376 // Expecting int or float element type.
1377 assert(ty.getElementType().isIntOrFloat() &&
1378 "Expected int or float element type.");
1379
1380 auto context = ty.getContext();
1381 auto rank = ty.getRank();
1382 SmallVector<int> laneLayout(rank, 1);
1383 SmallVector<int> laneData(rank, 1);
1384 if (packingSize.has_value()) {
1385 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1386 int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
1387 laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
1388 }
1389 laneLayout.back() = uArch->getSubgroupSize();
1390 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1391}
1392
1393// This function returns all layouts for the given sgCount, whose sgData:
1394// 1. Evenly divides the wgShape.
1395// 2. Is a multiple of instData.
1396// Example:
1397// wgShape = [128, 64], instData = [8, 16], sgCount = 32
1398// Returns layouts:
1399// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
// Only the first two entries of wgShape and instData are read. Each layout is
// represented as the pair (sgLayout0, sgLayout1).
1400using LayoutRepresentation = std::pair<int64_t, int64_t>;
1403 int64_t sgCount) {
 // Enumerate every factorization sgCount = sgLayout0 * sgLayout1.
1405 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
1406 if (sgCount % sgLayout0)
1407 continue;
1408 int64_t sgLayout1 = sgCount / sgLayout0;
1409 int64_t sgData0 = wgShape[0] / sgLayout0;
1410 int64_t sgData1 = wgShape[1] / sgLayout1;
 // Reject the factorization if it does not tile wgShape evenly or if the
 // resulting sgData is not a multiple of instData.
1411 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
1412 (sgData0 % instData[0] || sgData1 % instData[1]))
1413 continue;
1414 candidates.emplace_back(sgLayout0, sgLayout1);
1415 }
1416 // Sort primarily by how balanced they are
1417 // (i.e., minimize the absolute difference between the two dimensions), and
1418 // secondarily by the first dimension in ascending order.
1419 llvm::sort(candidates, [](const LayoutRepresentation &lhs,
1420 const LayoutRepresentation &rhs) {
 // NOTE(review): the int64_t difference is narrowed to int here.
1421 int diffLhs = std::abs(lhs.first - lhs.second);
1422 int diffRhs = std::abs(rhs.first - rhs.second);
1423 if (diffLhs != diffRhs)
1424 return diffLhs < diffRhs;
1425 return lhs.first < rhs.first;
1426 });
1427 return candidates;
1428}
1429
1430/// Helper function to compute inst_data vectors for DPAS operands A, B, and
1431/// C/D.
/// Each returned vector has the rank of its operand type, is 1 on all leading
/// dims, and holds the instruction tile sizes on the two innermost dims:
///   A   = {..., M, K}
///   B   = {..., K, N}
///   C/D = {..., M, N}
/// Returns std::nullopt if no supported M/N size divides the operand shape,
/// or (for DPAS_MX) if no supported K size is reported.
1432static std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
1434getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
1436 bool isDpasMx = false) {
1437 const int subgroupSize = uArch->getSubgroupSize();
1438
 // NOTE(review): uArchInstruction is obtained through dyn_cast and is
 // dereferenced below without a null check.
1439 const xegpu::uArch::MMAInstructionInterface *uArchInstruction;
1440 if (isDpasMx)
1441 uArchInstruction = dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1444 else
1445 uArchInstruction =
1446 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1448
1449 // M dimension is the second-to-last dim of A (handles batch dims).
1450 const unsigned dataALen = aTy.getShape()[aTy.getRank() - 2];
1451 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
1452 const int maxALen =
1453 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
1454
1455 // N dimension is the last dim of B.
1456 const unsigned dataBLen = bTy.getShape().back();
1457 auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
1458 const int maxBLen =
1459 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
1460
 // The N tile of C/D is derived from B's N extent (dataBLen), but using the
 // N sizes supported for the C/D element type.
1461 auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
1462 const int maxCLen =
1463 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
 // A result of -1 is treated as "no supported size divides the extent".
1464 if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
1465 return std::nullopt;
1466
1467 // For DPAS_MX, use getSupportedK to get the scaled K dimension.
1468 // assume single element in the returned vector.
 // Otherwise the K tile defaults to the subgroup size.
1469 int kDimSize = subgroupSize;
1470 if (isDpasMx) {
1471 auto supportedKLen = uArchInstruction->getSupportedK(aTy.getElementType());
1472 if (supportedKLen.empty())
1473 return std::nullopt;
1474 kDimSize = supportedKLen[0];
1475 }
1476
1477 SmallVector<int64_t> instDataA(aTy.getRank(), 1);
1478 instDataA[aTy.getRank() - 2] = maxALen;
1479 instDataA[aTy.getRank() - 1] = kDimSize;
1480 SmallVector<int64_t> instDataB(bTy.getRank(), 1);
1481 instDataB[bTy.getRank() - 2] = kDimSize;
1482 instDataB[bTy.getRank() - 1] = maxBLen;
1483 SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
1484 instDataCD[cdTy.getRank() - 2] = maxALen;
1485 instDataCD[cdTy.getRank() - 1] = maxCLen;
1486 return std::make_tuple(instDataA, instDataB, instDataCD);
1487}
1488
1489/// Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
1490/// Returns the three layouts if successful, nullopt otherwise.
/// All three operands share one 2D sg_layout. It is chosen among the layouts
/// that are valid for A, B and C/D at the same time, preferring the one that
/// equals the consumer's sg_layout and otherwise the most balanced one.
/// sg_data keeps the full K extent for A and B and divides M and N by the
/// chosen sg_layout.
1491static std::optional<
1492 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1493 xegpu::DistributeLayoutAttr>>
1495 VectorType bTy, VectorType cdTy,
1496 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1497 const xegpu::uArch::uArch *uArch) {
1498 auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
1499 if (!instDataVecs)
1500 return std::nullopt;
1501 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
 // Subgroup layout creation is limited to rank-2 operands (no batch dims).
1502 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
1503 instDataCD.size() == 2 &&
1504 "Sg layout creation expects valid 2D inst data");
1505
 // The consumer's sg_layout is only considered when a consumer layout exists
 // and is a workgroup-level layout.
1506 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
1507 if (consumerLayout && consumerLayout.isForWorkgroup()) {
1508 SmallVector<int64_t> sgLayoutD = consumerLayout.getEffectiveSgLayoutAsInt();
1509 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
1510 }
1511
1512 // Get all valid layouts for A, B and C/D operands
1513 auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
1514 auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
1515 auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
1516 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1517 return std::nullopt;
1518
1519 // Pick the best subgroup layout
 // A and C/D candidates go into sets for membership tests; B's candidates are
 // iterated in their sorted order.
1520 llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
1521 llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
1522 layoutsCD.end());
1523 std::optional<LayoutRepresentation> bestPick;
 // Accept a layout only if A's last dim divided by the second sg_layout entry
 // equals B's first dim divided by the first sg_layout entry.
1524 auto checkAlignedSgDataAB = [&](LayoutRepresentation sgLayout) {
1525 return aTy.getShape().back() / sgLayout.second ==
1526 bTy.getShape().front() / sgLayout.first;
1527 };
1528 for (auto &sgLayout : layoutsB) {
1529 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1530 if (!checkAlignedSgDataAB(sgLayout))
1531 continue;
1532 // Is in (A and B and CD) and matches consumer -> best pick
1533 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1534 bestPick = sgLayout;
1535 break;
1536 }
1537 // Is in (A and B and CD) layoutsB is ordered from most
1538 // balanced to least. So the first one we see is the most balanced one,
1539 // remember it and later only update if there is one that matches the
1540 // consumer.
1541 if (!bestPick)
1542 bestPick = sgLayout;
1543 }
1544 }
1545 if (!bestPick)
1546 return std::nullopt;
1547
1548 SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
1549 static_cast<int>(bestPick->second)};
 // A is split along M only and B along N only; both keep the full K extent.
1550 SmallVector<int> sgDataA = {static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1551 static_cast<int>(aTy.getShape()[1])};
1552 SmallVector<int> sgDataB = {
1553 static_cast<int>(bTy.getShape()[0]),
1554 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
 // C/D is split along both M and N.
1555 SmallVector<int> sgDataCD = {
1556 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1557 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1558
 // Only sg_layout and sg_data are set; the remaining four layout fields are
 // left null.
1559 auto dpasALayout =
1560 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1561 DenseI32ArrayAttr::get(context, sgDataA), nullptr,
1562 nullptr, nullptr, nullptr);
1563 auto dpasBLayout =
1564 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1565 DenseI32ArrayAttr::get(context, sgDataB), nullptr,
1566 nullptr, nullptr, nullptr);
1567 auto dpasCDLayout =
1568 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1569 DenseI32ArrayAttr::get(context, sgDataCD), nullptr,
1570 nullptr, nullptr, nullptr);
1571
1572 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
1573}
1574
1575/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
1576/// The numSg and consumerLayout (optional) are only used by sg layout
1577/// creation.
1578std::optional<
1579 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1580 xegpu::DistributeLayoutAttr>>
1581xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
1582 VectorType bTy, VectorType cdTy,
1583 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1584 const xegpu::uArch::uArch *uArch) {
1585 auto context = aTy.getContext();
1586 const auto *uArchInstruction =
1587 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1589
1590 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1591 assert(numSg > 0 &&
1592 "Number of subgroups must be provided for sg layout creation.");
1593 return getupDpasSubgroupLayouts(context, aTy, bTy, cdTy, consumerLayout,
1594 numSg, uArch);
1595 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1596 auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
1597 if (!instDataVecs)
1598 return std::nullopt;
1599 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1600 return std::make_tuple(
1601 xegpu::LayoutAttr::get(
1602 context, SmallVector<int>(instDataA.begin(), instDataA.end())),
1603 xegpu::LayoutAttr::get(
1604 context, SmallVector<int>(instDataB.begin(), instDataB.end())),
1605 xegpu::LayoutAttr::get(
1606 context, SmallVector<int>(instDataCD.begin(), instDataCD.end())));
1607 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1608 auto aLayout = getDefaultLaneLayout2DBlockIo(
1609 aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
1610 auto bLayout = getDefaultLaneLayout2DBlockIo(
1611 bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
1612 auto cdLayout = getDefaultLaneLayout2DBlockIo(
1613 cdTy, uArch /*, packingSize = std::nullopt */);
1614 return std::make_tuple(aLayout, bLayout, cdLayout);
1615 }
1616 return std::nullopt;
1617}
1618
1619/// Helper to create a scale layout derived from a matrix operand layout.
1620/// The scale layout is computed by mapping each dimension of the matrix layout
1621/// to the corresponding scale tensor dimension using the ratio between the
1622/// matrix and scale shapes.
///
/// Only the two innermost dimensions (rank - 2 and rank - 1 of the matrix
/// layout) are adjusted; every other entry is copied from `matrixLayout`.
/// `isBScale` selects which dimension of inst_data is rescaled and, together
/// with the uArch lane order, whether the lane layout is swapped.
///
/// Returns a null attribute when `scaleTy` or `matrixLayout` is null, or when
/// the scale shape is empty. Otherwise returns a LayoutAttr whose fields are
/// null wherever the corresponding matrix-layout field was empty.
1623static xegpu::DistributeLayoutAttr
1624createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy,
1625 VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout,
1626 bool isBScale, const xegpu::uArch::uArch *uArch) {
1627 if (!scaleTy || !matrixLayout)
1628 return nullptr;
1629
1630 // Calculate scaling factor by dividing matrix shape by scale shape
1631 ArrayRef<int64_t> matrixShape = matrixTy.getShape();
1632 ArrayRef<int64_t> scaleShape = scaleTy.getShape();
1633
1634 // Scale shapes can be 1D or 2D, handle both cases
// NOTE(review): only the empty shape is rejected here. The code below indexes
// scaleShape with rank - 2 and rank - 1, where rank comes from matrixLayout,
// so a 1D scale shape paired with a 2D layout would index out of range.
// Confirm that callers always pass a scale type of matching rank.
1635 if (scaleShape.empty())
1636 return nullptr;
1637
// NOTE(review): this listing skips source lines 1640-1641, so the argument of
// the dyn_cast below is not visible here. The result is dereferenced later
// without a null check.
1638 auto uArchInstruction =
1639 dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1642
1643 int64_t rank = matrixLayout.getRank();
1644 assert(rank >= 2 && "dpas layouts must be at least two dimensions");
1645
// Snapshot every level of the matrix layout; a vector is empty when the
// matrix layout does not carry that level (checked with .empty() below).
1646 SmallVector<int64_t> sgLayout = matrixLayout.getEffectiveSgLayoutAsInt();
1647 SmallVector<int64_t> sgData = matrixLayout.getEffectiveSgDataAsInt();
1648 SmallVector<int64_t> instData = matrixLayout.getEffectiveInstDataAsInt();
1649 SmallVector<int64_t> laneLayout = matrixLayout.getEffectiveLaneLayoutAsInt();
1650 SmallVector<int64_t> laneData = matrixLayout.getEffectiveLaneDataAsInt();
1651 auto order = matrixLayout.getOrder();
1652
// Subgroup level: sg_layout is copied unchanged. For each of the two
// innermost dims, sg_data becomes
//   max(scaleShape[d] / (matrixShape[d] / sgData[d]), 1).
// NOTE(review): matrixShape[d] / sgData[d] is an integer division and would
// be zero (division by zero in the outer quotient) if sgData[d] exceeded
// matrixShape[d] - presumably excluded by layout validity; verify.
1653 SmallVector<int> scaleSgLayout;
1654 SmallVector<int> scaleSgData;
1655 if (!sgLayout.empty() && !sgData.empty()) {
1656 scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
1657 scaleSgData.assign(sgData.begin(), sgData.end());
1658 scaleSgData[rank - 2] = std::max<int64_t>(
1659 scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
1660 scaleSgData[rank - 1] = std::max<int64_t>(
1661 scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
1662 }
1663
1664 // For DPAS_MX scales: if matrix has inst_data, scale needs adjusted
1665 // inst_data. Scale inst_data is derived from matrix inst_data divided by
1666 // scale factor.
// Only one dimension is rescaled: rank - 2 for the B scale, rank - 1
// otherwise. The other dimension keeps the matrix inst_data value.
1667 SmallVector<int> scaleInstData;
1668 if (!instData.empty()) {
1669 scaleInstData.assign(instData.begin(), instData.end());
1670 if (isBScale)
1671 scaleInstData[rank - 2] = std::max<int64_t>(
1672 scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
1673 1);
1674 else
1675 scaleInstData[rank - 1] = std::max<int64_t>(
1676 scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
1677 1);
1678 }
1679
// Lane level: start from the matrix lane layout/data. When exactly one of
// isBScale and the uArch row-major flag is set, swap the two innermost
// lane_layout entries and clamp the rank - 2 entry to the scale extent.
// lane_data is then recomputed for both innermost dims as
//   max(scaleShape[d] / scaleLaneLayout[d], 1).
1680 SmallVector<int> scaleLaneLayout;
1681 SmallVector<int> scaleLaneData;
1682 if (!laneLayout.empty() && !laneData.empty()) {
1683 scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
1684 scaleLaneData.assign(laneData.begin(), laneData.end());
1685 bool isRowMajor = uArchInstruction->isLaneLayoutRowMajorOrder();
1686 if (isBScale ^ isRowMajor) {
1687 std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
1688 scaleLaneLayout[rank - 2] =
1689 std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
1690 }
1691 scaleLaneData[rank - 2] =
1692 std::max<int64_t>(scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
1693 scaleLaneData[rank - 1] =
1694 std::max<int64_t>(scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
1695 }
// Assemble the result; a level that stayed empty is passed as a null
// attribute, and the dimension order is taken from the matrix layout as is.
1696 return xegpu::LayoutAttr::get(
1697 context,
1698 scaleSgLayout.empty() ? nullptr
1699 : DenseI32ArrayAttr::get(context, scaleSgLayout),
1700 scaleSgData.empty() ? nullptr
1701 : DenseI32ArrayAttr::get(context, scaleSgData),
1702 scaleInstData.empty() ? nullptr
1703 : DenseI32ArrayAttr::get(context, scaleInstData),
1704 scaleLaneLayout.empty()
1705 ? nullptr
1706 : DenseI32ArrayAttr::get(context, scaleLaneLayout),
1707 scaleLaneData.empty() ? nullptr
1708 : DenseI32ArrayAttr::get(context, scaleLaneData),
1709 order);
1710}
1711
1712/// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
1713/// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
1714/// creation.
///
/// The returned tuple is ordered (A, B, C/D, A_scale, B_scale). In every
/// branch the A/B/C-D layouts are built first and the two scale layouts are
/// then derived from the A and B layouts through createScaleLayout, with
/// isBScale = false for A and true for B. createScaleLayout can return a null
/// attribute (for example when the scale type is null), so the last two tuple
/// entries may be null even when the optional holds a value.
///
/// Returns std::nullopt when the subgroup or inst_data helper fails, or when
/// `layoutKind` is not Subgroup, InstData, or Lane.
1715std::optional<
1716 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1717 xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1718 xegpu::DistributeLayoutAttr>>
1719xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
1720 VectorType bTy, VectorType cdTy, VectorType aScaleTy,
1721 VectorType bScaleTy,
1722 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1723 const xegpu::uArch::uArch *uArch) {
1724 auto context = aTy.getContext();
1725
// Subgroup level: reuse the same subgroup-layout helper that the plain dpas
// setup calls.
1726 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1727 assert(numSg > 0 &&
1728 "Number of subgroups must be provided for sg layout creation.");
1729 auto dpasLayouts = getupDpasSubgroupLayouts(context, aTy, bTy, cdTy,
1730 consumerLayout, numSg, uArch);
1731 if (!dpasLayouts)
1732 return std::nullopt;
1733
1734 auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
1735
1736 // Create scale layouts
1737 auto aScaleLayout =
1738 createScaleLayout(context, aTy, aScaleTy, dpasALayout, false, uArch);
1739
1740 auto bScaleLayout =
1741 createScaleLayout(context, bTy, bScaleTy, dpasBLayout, true, uArch);
1742
1743 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1744 bScaleLayout);
// Instruction level: inst_data vectors are requested with isDpasMx = true and
// narrowed to int before being wrapped in LayoutAttrs.
1745 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1746 auto instDataVecs =
1747 getDpasInstDataVectors(aTy, bTy, cdTy, uArch, /*isDpasMx=*/true);
1748 if (!instDataVecs)
1749 return std::nullopt;
1750 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1751
1752 auto dpasALayout = xegpu::LayoutAttr::get(
1753 context, SmallVector<int>(instDataA.begin(), instDataA.end()));
1754 auto dpasBLayout = xegpu::LayoutAttr::get(
1755 context, SmallVector<int>(instDataB.begin(), instDataB.end()));
1756 auto dpasCDLayout = xegpu::LayoutAttr::get(
1757 context, SmallVector<int>(instDataCD.begin(), instDataCD.end()));
1758
1759 // Create scale layouts
1760 auto aScaleLayout =
1761 createScaleLayout(context, aTy, aScaleTy, dpasALayout, false, uArch);
1762 auto bScaleLayout =
1763 createScaleLayout(context, bTy, bScaleTy, dpasBLayout, true, uArch);
1764
1765 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1766 bScaleLayout);
// Lane level: default 2D block-IO lane layouts; A and B use the packed-format
// bit sizes from the uArch instruction, and B passes true as the helper's
// fourth argument (vnni).
// NOTE(review): this listing skips source line 1770, so the InstructionKind
// passed to getInstruction is not visible here. The dyn_cast result is
// dereferenced below without a null check.
1767 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1768 const auto *uArchInstruction =
1769 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1771 auto aLayout = getDefaultLaneLayout2DBlockIo(
1772 aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
1773 auto bLayout = getDefaultLaneLayout2DBlockIo(
1774 bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
1775 auto cdLayout = getDefaultLaneLayout2DBlockIo(cdTy, uArch);
1776
1777 // Create scale layouts
1778 auto aScaleLayout =
1779 createScaleLayout(context, aTy, aScaleTy, aLayout, false, uArch);
1780 auto bScaleLayout =
1781 createScaleLayout(context, bTy, bScaleTy, bLayout, true, uArch);
1782
1783 return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
1784 bScaleLayout);
1785 }
// Any other layout kind is not handled.
1786 return std::nullopt;
1787}
1788
1789xegpu::DistributeLayoutAttr xegpu::inferSourceLayoutFromResultForNonAnchorOp(
1790 OpOperand &operand, xegpu::DistributeLayoutAttr resLayout) {
1791 if (!resLayout)
1792 return nullptr;
1793 Operation *op = operand.getOwner();
1794 unsigned idx = operand.getOperandNumber();
1795
1796 // For vector::BroadcastOp, infer the source layout from the result layout.
1797 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
1798 auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
1799 if (!srcTy)
1800 return nullptr;
1802 resLayout, broadcast.getResultVectorType().getShape(),
1803 srcTy.getShape());
1804 }
1805
1806 // For vector::MultiDimReductionOp, infer source layout from result layout
1807 // using reduction dims. Acc operand is expected to have the same layout as
1808 // the result.
1809 if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
1810 if (idx == 0) {
1811 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
1812 return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
1813 }
1814 if (idx == 1)
1815 return resLayout;
1816 }
1817
1818 if (auto reduction = dyn_cast<vector::ReductionOp>(op))
1819 return xegpu::inferReductionSourceLayout(resLayout);
1820
1821 // For vector::BitCastOp, infer source layout from result layout using
1822 // element type bitwidths.
1823 if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1824 int resElemBitWidth =
1825 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1826 int srcElemBitWidth =
1827 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1828 return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
1829 srcElemBitWidth);
1830 }
1831
1832 // For vector::ShapeCastOp, infer source layout from result layout using
1833 // shapes.
1834 if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
1836 resLayout, shapeCast.getResultVectorType().getShape(),
1837 shapeCast.getSourceVectorType().getShape());
1838 }
1839
1840 // For vector::InsertStridedSliceOp, infer source layout from result layout.
1841 // Dest vector must have the same layout as the result.
1842 if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1843 if (idx == 0) {
1845 resLayout, insertSlice.getDestVectorType().getShape(),
1846 insertSlice.getSourceVectorType().getShape());
1847 }
1848 if (idx == 1)
1849 return resLayout;
1850 }
1851
1852 // For vector::Insert Op, infer source layout from result layout using
1853 // shapes.
1854 if (auto insert = dyn_cast<vector::InsertOp>(op)) {
1855 VectorType resVecTy = dyn_cast<VectorType>(insert.getResult().getType());
1856 VectorType valueToStoreTy =
1857 dyn_cast<VectorType>(insert.getValueToStore().getType());
1858
1859 if ((idx == 0) && valueToStoreTy) {
1860 return xegpu::inferInsertSourceLayout(resLayout, resVecTy.getShape(),
1861 valueToStoreTy.getShape());
1862 }
1863 if (idx == 1)
1864 return resLayout;
1865 }
1866
1867 // For vector::Extract Op, infer source layout from result layout using
1868 // shapes.
1869 if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
1870 VectorType srcVecTy = dyn_cast<VectorType>(extract.getSource().getType());
1871 VectorType resVecTy = dyn_cast<VectorType>(extract.getResult().getType());
1872 if (!srcVecTy || !resVecTy)
1873 return nullptr;
1874 return xegpu::inferExtractSourceLayout(resLayout, resVecTy.getShape(),
1875 srcVecTy.getShape());
1876 }
1877
1878 // For vector::TransposeOp, infer source layout from result layout using
1879 // permutation.
1880 if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
1881 return xegpu::inferTransposeSourceLayout(resLayout,
1882 transpose.getPermutation());
1883 }
1884
1885 // For vector::BitCastOp, infer source layout from result layout using
1886 // element type bitwidths.
1887 if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1888 int resElemBitWidth =
1889 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1890 int srcElemBitWidth =
1891 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1892 return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
1893 srcElemBitWidth);
1894 }
1895
1896 // for vector::interleave
1897 if (auto interleave = dyn_cast<vector::InterleaveOp>(op)) {
1898 return xegpu::inferInterleaveSourceLayout(resLayout);
1899 }
1900
1901 // for vector::deinterleave
1902 if (auto deinterleave = dyn_cast<vector::DeinterleaveOp>(op)) {
1903 return xegpu::inferDeinterleaveSourceLayout(resLayout);
1904 }
1905
1906 // For vector::ExtractStridedSliceOp, simply return result layout
1907 if (dyn_cast<vector::ExtractStridedSliceOp>(op))
1908 return resLayout;
1909
1910 // For elementwise operations, all operands must have the same layout as the
1911 // result.
1913 return resLayout;
1914
1915 return nullptr;
1916}
1917
1918xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
1919 Operation *op = operand.getOwner();
1920 // Anchor ops declare the layout they
1921 // require on each operand. Trust that declaration directly so that
1922 // ResolveLayoutConflicts compares producer-vs-declared
1923 if (isa<xegpu::AnchorLayoutInterface>(op))
1924 return xegpu::getDistributeLayoutAttr(operand);
1925 // For non-anchor ops, derive the operand layout from the op's result
1926 // layout via op-specific semantics.
1927 xegpu::DistributeLayoutAttr resLayout;
1928 if (op->getNumResults() == 1 || isa<vector::DeinterleaveOp>(op))
1929 resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1930 return inferSourceLayoutFromResultForNonAnchorOp(operand, resLayout);
1931}
return success()
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition PDL.cpp:62
lhs
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
std::pair< int64_t, int64_t > LayoutRepresentation
static xegpu::DistributeLayoutAttr createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy, VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout, bool isBScale, const xegpu::uArch::uArch *uArch)
Helper to create a scale layout derived from a matrix operand layout.
static std::optional< std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > > getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy, const xegpu::uArch::uArch *uArch, bool isDpasMx=false)
Helper function to compute inst_data vectors for DPAS operands A, B, and C/D.
static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result)
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, bool isChunkedStore, int maxChunkSize, ArrayRef< int64_t > srcShape, int subgroupSize)
Sets up the anchor layout for store scatter and store matrix operation.
static std::optional< std::tuple< xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr > > getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy, VectorType bTy, VectorType cdTy, xegpu::DistributeLayoutAttr consumerLayout, int numSg, const xegpu::uArch::uArch *uArch)
Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
static void propagateResultsToRegularOperands(Operation *op)
static void propagateRegionResultsToYieldOperands(mlir::RegionBranchTerminatorOpInterface yieldOp)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout)
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch, std::optional< unsigned > packingSize=std::nullopt, bool vnni=false)
static void walkRegionBackward(Region &region, llvm::function_ref< void(Operation *)> visit)
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad, int maxChunkSize, ArrayRef< int64_t > resShape, int subgroupSize)
Sets up the anchor layout for load gather and load matrix operation.
Block represents an ordered list of Operations.
Definition Block.h:33
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
bool hasAttrOfType(NameT &&name)
Definition Operation.h:600
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:408
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:511
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
Definition Operation.h:497
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:625
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool empty()
Definition Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
bool matchDimCollapse(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &collapseDims)
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert operation.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr inferSourceLayoutFromResultForNonAnchorOp(OpOperand &operand, DistributeLayoutAttr resLayout)
Infers the source layout attribute for an operand using result layout attribute.
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
llvm::function_ref< DistributeLayoutAttr(Value)> GetLayoutFnTy
Callable returning the propagated layout for a given Value, used by the layout-propagation helpers be...
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an extract operation.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
bool isTriviallyRematerializable(Operation *op)
Returns true if op is safe and cheap to clone: it has no side effects, no regions,...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateRegionArgsToInits(RegionBranchOpInterface regionOp, GetLayoutFnTy getLayoutOfValue)
Propagate layouts from a region branch op's region entry block arguments back to its init operands.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:156
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:168