1//===---- XeGPULayoutImpl.cpp - MLIR Utilities for XeGPUOps
2//------------------===//
3//
4// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements layout utility functions for XeGPU dialect
11// transformation.
12//
13//===----------------------------------------------------------------------===//
14
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/ValueRange.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/Support/FormatVariadic.h"
31#include <cstdint>
32#include <numeric>
33
34using namespace mlir;
35
39 out.reserve(attrs.size());
40
41 for (auto attr : attrs) {
42 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
43 auto newLayout = dist.dropSgLayoutAndData();
44 if (newLayout)
45 out.emplace_back(attr.getName(), newLayout);
46 } else {
47 out.push_back(attr);
48 }
49 }
50
51 return out;
52}
53
57 out.reserve(attrs.size());
58
59 for (auto attr : attrs) {
60 if (auto dist = dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
61 auto newLayout = dist.dropInstData();
62 if (newLayout)
63 out.emplace_back(attr.getName(), newLayout);
64 } else {
65 out.push_back(attr);
66 }
67 }
68
69 return out;
70}
71
72// Sets the layout on a TensorDesc value by updating its type to include
73// the given layout, if the type does not already have a layout attached.
74static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout) {
75 auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(val.getType());
76 if (!tensorDescTy || tensorDescTy.getLayoutAttr())
77 return;
78 auto typeWithLayout = xegpu::TensorDescType::get(
79 tensorDescTy.getContext(), tensorDescTy.getShape(),
80 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
81 val.setType(typeWithLayout);
82}
83
84// walkRegionBackward() is a recursive function. It processes the given
85// region (typically a function body) and any nested regions in reverse
86// topological order, invoking the visit callback on each operation.
87static void walkRegionBackward(Region &region,
88 function_ref<void(Operation *)> visit) {
89
90 // Use post-order traversal to process blocks in reverse topological order.
91 // This ensures that use blocks are visited before def blocks, which is
92 // required for backward layout propagation.
93 if (region.empty())
94 return;
95 llvm::ReversePostOrderTraversal<Region *> rpot(&region);
96 SmallVector<Block *> blocks(rpot.begin(), rpot.end());
97 for (Block *block : llvm::reverse(blocks)) {
98 // ops: back -> front
99 for (Operation &op : llvm::reverse(*block)) {
100 // Make sure we first visit the inside of the region op (so the yield op
101 // comes first) and then the region op itself.
102 // Regions are iterated in forward order so that for multi-region ops
103 // like scf.while, earlier regions (e.g., "before/cond") are processed
104 // first. This ensures that when a later region's terminator (e.g., "do"
105 // yield) needs the layout of an earlier region's block args, those
106 // layouts are already available from use points.
107 for (Region &nested : op.getRegions())
108 walkRegionBackward(nested, visit);
109
110 visit(&op);
111 }
112 }
113}
114
115static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
116 xegpu::DistributeLayoutAttr layout = nullptr;
117 for (OpOperand &use : result.getUses()) {
118 if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
119 if (!layout)
120 layout = tmpLayout;
121 break;
122 }
123 }
124 return layout;
125}
126
127// For regular operations: First the result layouts are propagated from uses.
128// Then the result layouts are propagated to the operands.
130 if (op->getNumResults() == 0)
131 return;
132 if (op->getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
133 return;
134 OpResult result = op->getResult(0);
135 xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
136 Type resultType = result.getType();
137
138 if (!resLayout)
139 return;
140
141 // Recover layout for TensorDesc type results by updating the type to
142 // include the layout.
143 if (isa<xegpu::TensorDescType>(resultType))
144 setTensorDescLayout(result, resLayout);
145
146 // Recover layout for vector type results, or for multi-reduction ops which
147 // may reduce to a scalar that still needs a layout.
148 if (isa<VectorType>(resultType) || isa<vector::MultiDimReductionOp>(op))
149 xegpu::setTemporaryLayout(result, resLayout);
150
151 for (OpOperand &opr : op->getOpOperands()) {
152 xegpu::DistributeLayoutAttr operandLayout =
154 // Recover layout for vector operands
155 if (isa<VectorType>(opr.get().getType()) && operandLayout)
156 xegpu::setTemporaryLayout(opr, operandLayout);
157 }
158}
159
160// Propagate layout from region op results and sibling region block args
161// to yield/condition operands. For each successor of this terminator:
162// - Parent successor: propagate from parent op's result layouts (use points).
163// - Region successor: propagate from target region's block arg layouts (use
164// points), e.g., scf.yield in "after/do" region propagates to "before/cond"
165// block args.
167 mlir::RegionBranchTerminatorOpInterface yieldOp) {
168 auto regionBranchOp =
169 dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
170 if (!regionBranchOp)
171 return;
172
173 SmallVector<RegionSuccessor> successors;
174 SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
175 yieldOp.getSuccessorRegions(operandAttrs, successors);
176
177 for (const RegionSuccessor &successor : successors) {
178 OperandRange succOps = yieldOp.getSuccessorOperands(successor);
179 if (succOps.empty())
180 continue;
181 unsigned beginIdx = succOps.getBeginOperandIndex();
182 ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
183 unsigned count = std::min<unsigned>(succOps.size(), successorInputs.size());
184
185 for (unsigned i = 0; i < count; ++i) {
186 xegpu::DistributeLayoutAttr layout;
187 if (successor.isParent()) {
188 // For parent successor, get layout from external use points of the
189 // parent op's results.
190 auto regionResult = regionBranchOp->getResult(i);
191 layout = getLayoutFromUsePoints(regionResult);
192 if (layout) {
193 // Set the layout for the region op's result, e.g., an scf.for result.
194 xegpu::setTemporaryLayout(regionResult, layout);
195 if (isa<xegpu::TensorDescType>(regionResult.getType()))
196 setTensorDescLayout(regionResult, layout);
197 }
198 } else {
199 // For region successor, get layout from the target region's block
200 // arg use points (e.g., "before/cond" region args for scf.while
201 // "after/do" yield).
202 layout = getLayoutFromUsePoints(successorInputs[i]);
203 }
204 if (!layout)
205 continue;
206 auto operandType = succOps[i].getType();
207 if (isa<VectorType>(operandType) ||
208 dyn_cast<xegpu::TensorDescType>(operandType))
209 // recover layout for yield op operands
210 xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i), layout);
211 }
212 }
213}
214
215// Propagate layout from region arguments to region op's init operands. This
216// sets the temporary layout for region arguments and init operands.
217static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
218 // Iterate all regions of the region op. For each block argument that has a
219 // layout (determined from its use points), trace back to find the
220 // corresponding init operand of the regionOp and set the layout on it.
221 // This works generically for scf.for, scf.while, and other
222 // RegionBranchOpInterface ops.
223 for (Region &region : regionOp->getRegions()) {
224 RegionSuccessor regionSuccessor(&region);
225 // Use getSuccessorInputs to get the block arguments that correspond to
226 // predecessor operands. This correctly handles ops like scf.for where
227 // the induction variable is a block arg but not a successor input.
228 ValueRange successorInputs = regionOp.getSuccessorInputs(regionSuccessor);
229 for (auto [inputIdx, regionArg] : llvm::enumerate(successorInputs)) {
230 auto layout = getLayoutFromUsePoints(regionArg);
231 if (!layout)
232 continue;
233
234 // Recover layout for tensor_desc block args by updating the type.
235 if (isa<xegpu::TensorDescType>(regionArg.getType()))
236 setTensorDescLayout(regionArg, layout);
237
238 // Recover layout for region op operands, like scf.for's init operands.
239 // Find all predecessor values that flow into this block argument.
240 SmallVector<Value> predValues;
241 regionOp.getPredecessorValues(regionSuccessor, inputIdx, predValues);
242 for (Value predVal : predValues) {
243 // Match predecessor value to an operand of the regionOp.
244 for (OpOperand &operand : regionOp->getOpOperands()) {
245 if (operand.get() == predVal)
246 xegpu::setTemporaryLayout(operand, layout);
247 }
248 }
249 }
250 }
251}
252
253// Prerequisite for Layout Recovery
254// It relies on the following invariants:
255// 1. There is no layout conflict between different uses of the same definition.
256// 2. Each definition has a well-defined layout requirement at its use point.
257// - Every definition must have at least one use that appears after it in
258// topological order.
259// - TODO: If a definition has no such use (e.g., a loop result or region
260// output), an explicit convert_layout operation is inserted to create a
261// use.
262// - Only the result of convert_layout is permitted to have no subsequent
263// use.
264//
265// The recovery proceeds by scanning the operation in reverse topological order
266// as follows:
267// For regular operations: First the result layouts are propagated from uses.
268// Then the result layouts are propagated to operands.
269//
270// For region operations (e.g., loops):
271// - When backward propagation reaches a region op, it sets the layout of
272// the region op’s results according to use points like regular ops.
273// - Then, the result layouts (such as a loop output) are propagated to
274// their corresponding operands in the yield.
275// - When backward propagation reaches the first operation inside the
276// region, the pass examines the region op’s initialization list,
277// propagating from region arguments to the corresponding initialization
278// operands.
279// - This ensures that layouts are consistently propagated
280// across region boundaries while preserving a single well-defined use for
281// each definition at the region-op level.
283 auto processFunc = [&](Region &body, StringRef funcName) {
284 walkRegionBackward(body, [&](Operation *op) {
285 if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
286 propagateRegionArgsToInits(regionOp);
287 } else if (auto yieldOp =
288 dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
290 } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
292 }
293 });
294 };
296 rootOp->walk([&](func::FuncOp func) {
297 processFunc(func.getBody(), func.getSymName());
298 });
299 rootOp->walk([&](gpu::GPUFuncOp func) {
300 processFunc(func.getBody(), func.getName());
301 });
302
303 return true;
304}
305
306template <typename T, typename>
307void xegpu::removeLayoutAttr(const T &operandOrResult) {
308 Operation *owner = operandOrResult.getOwner();
309 std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
310 if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
311 owner->removeAttr(name);
312}
313
314// Explicit instantiation for OpResult
315template void
316xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &);
317
318// Explicit instantiation for OpOperand
319template void
320xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &);
321
323 op->walk([&](Operation *nestOp) {
324 // Remove all attributes of DistributeLayoutAttr type
325 SmallVector<StringAttr> attrsToRemove;
326 for (auto namedAttr : nestOp->getAttrs()) {
327 if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
328 attrsToRemove.push_back(namedAttr.getName());
329 }
330 for (auto attrName : attrsToRemove)
331 nestOp->removeAttr(attrName);
332 });
333}
334
336 op->walk([&](Operation *nestOp) {
337 SmallVector<StringAttr> attrsToRemove;
338 for (auto namedAttr : nestOp->getDiscardableAttrs()) {
339 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
340 attrsToRemove.push_back(namedAttr.getName());
341 }
342 for (auto attrName : attrsToRemove)
343 nestOp->removeDiscardableAttr(attrName);
344 });
345}
346
347/// Infers the source layout attribute for a broadcast operation given the
348/// result layout attribute, result shape, and source shape.
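/// For illustration (hypothetical shapes, not taken from a test): with
/// resShape = [32, 64] and srcShape = [64], dimDiff is 1 and no unit-dim
/// broadcast is detected, so the inferred source layout is a SliceAttr
/// slicing the result layout over the leading dimension (dims = [0]).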
349xegpu::DistributeLayoutAttr
350xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
351 ArrayRef<int64_t> resShape,
352 ArrayRef<int64_t> srcShape) {
353
354 SmallVector<int64_t> bcastDims;
355 size_t dimDiff = resShape.size() - srcShape.size();
356 auto bcastSourceLayout = resLayout;
357 for (size_t i = dimDiff; i < resShape.size(); i++) {
358 if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
359 bcastDims.push_back(i);
360 }
361
362 // The sg_layout and lane_layout for unit dimensions are preserved so they can
363 // be propagated to the producer op and potentially used by the multi-reduction op.
364 if (!bcastDims.empty())
365 bcastSourceLayout = bcastSourceLayout.setUnitDimData(bcastDims);
366
367 if (dimDiff > 0) {
368 SmallVector<int64_t> sliceDims;
369 for (size_t i = 0; i < dimDiff; i++)
370 sliceDims.push_back(i);
371 bcastSourceLayout = xegpu::SliceAttr::get(
372 resLayout.getContext(), bcastSourceLayout,
373 DenseI64ArrayAttr::get(resLayout.getContext(), sliceDims));
374 }
375 return bcastSourceLayout;
376}
377
378/// Infers the source layout attribute for a reduction operation given the
379/// result layout attribute and reduced dims.
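/// For illustration (hypothetical attribute, not taken from a test): if the
/// result layout is #xegpu.slice<#xegpu.layout<sg_layout = [4, 8]>, dims = [1]>
/// and reduceDims = [1], the inferred source layout is the slice's parent,
/// #xegpu.layout<sg_layout = [4, 8]>.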
380xegpu::DistributeLayoutAttr
381xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
382 SmallVector<int64_t> reduceDims) {
383
384 assert(isa<xegpu::SliceAttr>(resLayout) &&
385 "reduction result layout must be slice layout");
386
387 xegpu::SliceAttr sliceLayout = dyn_cast<xegpu::SliceAttr>(resLayout);
388
389 assert((reduceDims == sliceLayout.getDims().asArrayRef()) &&
390 "reduction dims must match with slice dims");
391
392 return sliceLayout.getParent();
393}
394
395xegpu::DistributeLayoutAttr
396xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
397 return xegpu::inferMultiReductionSourceLayout(resLayout, {0});
398}
399
400/// Infers the source layout attribute for a transpose operation given the
401/// result layout attribute and permutation.
402xegpu::DistributeLayoutAttr
403xegpu::inferTransposeSourceLayout(xegpu::DistributeLayoutAttr resLayout,
404 ArrayRef<int64_t> permutation) {
405 return resLayout.transposeDims(permutation);
406}
407
408/// Infers the source layout attribute for a bitcast operation given the
409/// result layout attribute, result element type bitwidth, and source element
410/// type bitwidth.
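/// For illustration (hypothetical values): casting f16 -> f32
/// (srcElemTyBitWidth = 16, resElemTyBitWidth = 32, ratio = 2) with a result
/// layout whose inst_data is [1, 16] gives a source layout whose innermost
/// inst_data entry is doubled to 32; casting in the other direction divides
/// the innermost entries instead.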
411xegpu::DistributeLayoutAttr
412xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
413 int resElemTyBitWidth, int srcElemTyBitWidth) {
414
415 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
416 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
417 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
418 size_t sgDataSize = sgData.size();
419 size_t instDataSize = instData.size();
420 size_t laneDataSize = laneData.size();
421 int64_t sgDataValue = -1;
422 int64_t instDataValue = -1;
423 int64_t laneDataValue = -1;
424 int64_t dim = resLayout.getRank() - 1;
425
426 if (srcElemTyBitWidth <= resElemTyBitWidth) {
427 int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
428 if (sgDataSize)
429 sgDataValue = sgData.back() * bitWidthRatio;
430 if (instDataSize)
431 instDataValue = instData.back() * bitWidthRatio;
432 if (laneDataSize)
433 laneDataValue = laneData.back() * bitWidthRatio;
434 } else {
435 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
436 if (sgDataSize) {
437 assert((sgData.back() % bitWidthRatio) == 0 &&
438 "sgData not divisible by bitWidthRatio");
439 sgDataValue = sgData.back() / bitWidthRatio;
440 }
441 if (instDataSize) {
442 assert((instData.back() % bitWidthRatio) == 0 &&
443 "instData not divisible by bitWidthRatio");
444 instDataValue = instData.back() / bitWidthRatio;
445 }
446 if (laneDataSize) {
447 assert((laneData.back() % bitWidthRatio) == 0 &&
448 "laneData not divisible by bitWidthRatio");
449 laneDataValue = laneData.back() / bitWidthRatio;
450 }
451 }
452
453 xegpu::DistributeLayoutAttr finalSrcLayout;
454 finalSrcLayout =
455 resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
456
457 return finalSrcLayout;
458}
459
460/// Infers the source layout attribute for an interleave operation given the
461/// result layout attribute. Interleave doubles the size of the innermost
462/// dimension, so the layout inference is similar to bitcast where the source
463/// element type is larger than the result element type (ratio = 2).
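/// For illustration (hypothetical values): a result layout with
/// inst_data = [1, 64] and lane_data = [1, 4] infers a source layout whose
/// innermost inst_data / lane_data entries are halved to 32 and 2,
/// respectively (the values must be divisible by 2).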
464xegpu::DistributeLayoutAttr
465xegpu::inferInterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
466
467 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
468 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
469 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
470 size_t sgDataSize = sgData.size();
471 size_t instDataSize = instData.size();
472 size_t laneDataSize = laneData.size();
473 int64_t sgDataValue = -1;
474 int64_t instDataValue = -1;
475 int64_t laneDataValue = -1;
476 int64_t dim = resLayout.getRank() - 1;
477
478 // Interleave doubles the innermost dimension, so we need to halve the
479 // layout values (similar to bitcast with ratio = 2)
480 constexpr int ratio = 2;
481 if (sgDataSize) {
482 assert((sgData.back() % ratio) == 0 &&
483 "sgData not divisible by interleave ratio");
484 sgDataValue = sgData.back() / ratio;
485 }
486 if (instDataSize) {
487 assert((instData.back() % ratio) == 0 &&
488 "instData not divisible by interleave ratio");
489 instDataValue = instData.back() / ratio;
490 }
491 if (laneDataSize) {
492 assert((laneData.back() % ratio) == 0 &&
493 "laneData not divisible by interleave ratio");
494 laneDataValue = laneData.back() / ratio;
495 }
496
497 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
498}
499
500/// Infers the source layout attribute for a deinterleave operation given the
501/// result layout attribute. Deinterleave halves the size of the innermost
502/// dimension, so the layout inference is similar to bitcast where the source
503/// element type is smaller than the result element type (ratio = 2).
504xegpu::DistributeLayoutAttr
505xegpu::inferDeinterleaveSourceLayout(xegpu::DistributeLayoutAttr resLayout) {
506
507 SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
508 SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
509 SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
510 size_t sgDataSize = sgData.size();
511 size_t instDataSize = instData.size();
512 size_t laneDataSize = laneData.size();
513 int64_t sgDataValue = -1;
514 int64_t instDataValue = -1;
515 int64_t laneDataValue = -1;
516 int64_t dim = resLayout.getRank() - 1;
517
518 // Deinterleave halves the innermost dimension, so we need to double the
519 // layout values (similar to bitcast with ratio = 2)
520 constexpr int ratio = 2;
521 if (sgDataSize)
522 sgDataValue = sgData.back() * ratio;
523 if (instDataSize)
524 instDataValue = instData.back() * ratio;
525 if (laneDataSize)
526 laneDataValue = laneData.back() * ratio;
527
528 return resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
529}
530
531/// Infers the source layout attribute for an insert strided slice operation
532/// given the result layout attribute, result shape, and source shape. Removes
533/// leading dimensions from the result layout to match the source shape size.
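/// For illustration (hypothetical shapes): with resShape = [1, 8, 16] and
/// srcShape = [8, 16], the leading dimension (which must not be distributed
/// across subgroups or lanes) is dropped from the result layout, e.g. an
/// inst_data of [1, 8, 16] becomes [8, 16] for the source.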
534xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
535 xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
536 ArrayRef<int64_t> srcShape) {
537
538 int srcShapeSize = srcShape.size();
539 int resShapeSize = resShape.size();
540 int dimDiff = resShapeSize - srcShapeSize;
541
542 if (dimDiff > 0) {
543 // assert that the leading dimensions being sliced off are not distributed
544 // (i.e. sg_layout and lane_layout for those dimensions are all 1)
545 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
546 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
547 for (int i = 0; i < dimDiff; i++) {
548 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
549 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
550 "Leading dimensions being sliced off must not be distributed");
551 }
552 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
553 }
554 return resLayout;
555}
556
557/// Infers the source layout attribute for an insert operation
558/// given the result layout attribute, result shape, and source shape. Removes
559/// leading dimensions from the result layout to match the source shape size.
560// TODO: add propagation support for insert op
561xegpu::DistributeLayoutAttr
562xegpu::inferInsertSourceLayout(xegpu::DistributeLayoutAttr resLayout,
563 ArrayRef<int64_t> resShape,
564 ArrayRef<int64_t> srcShape) {
565
566 int srcShapeSize = srcShape.size();
567 int resShapeSize = resShape.size();
568 int dimDiff = resShapeSize - srcShapeSize;
569
570 if (dimDiff > 0) {
571 // assert that the leading dimensions being sliced off are not distributed
572 // (i.e. sg_layout and lane_layout for those dimensions are all 1)
573 auto resSgLayout = resLayout.getEffectiveSgLayoutAsInt();
574 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
575 for (int i = 0; i < dimDiff; i++) {
576 assert((resSgLayout.size() == 0 || resSgLayout[i] == 1) &&
577 (resLaneLayout.size() == 0 || resLaneLayout[i] == 1) &&
578 "Leading dimensions being sliced off must not be distributed");
579 }
580 return resLayout.dropDims(llvm::to_vector(llvm::seq<int64_t>(0, dimDiff)));
581 }
582 return resLayout;
583}
584
585/// Infers the source layout attribute for an extract operation
586/// given the result layout attribute, result shape, and source shape. Adds
587/// leading unit dimensions to the result layout to match the source shape rank.
588// TODO: add layout attribute interface: expandDims() and use it here.
589// TODO: add propagation support for extract op
590xegpu::DistributeLayoutAttr
591xegpu::inferExtractSourceLayout(xegpu::DistributeLayoutAttr resLayout,
592 ArrayRef<int64_t> resShape,
593 ArrayRef<int64_t> srcShape) {
594
595 int srcShapeSize = srcShape.size();
596 int resShapeSize = resShape.size();
597 int dimDiff = srcShapeSize - resShapeSize;
598 auto context = resLayout.getContext();
599 // construct the source layout by adding unit dimensions to the front of
600 // result layout
601 if (dimDiff > 0) {
602 auto sgLayout = resLayout.getEffectiveSgLayoutAsInt();
603 auto sgData = resLayout.getEffectiveSgDataAsInt();
604 auto instData = resLayout.getEffectiveInstDataAsInt();
605 auto laneLayout = resLayout.getEffectiveLaneLayoutAsInt();
606 auto laneData = resLayout.getEffectiveLaneDataAsInt();
607 auto order = resLayout.getEffectiveOrderAsInt();
608
609 // Example: result shape is 3D with order [1, 2, 0], source shape is 5D
610 // (adding 2 leading dimensions). Expected source order: [3, 4, 2, 1, 0]
611 // Step 1: shift existing order by dimDiff: [1, 2, 0] -> [3, 4, 2]
612 // Step 2: append new leading dims in reverse (slowest first): [3, 4, 2, 1,
613 // 0]
614
615 // Shift existing dimension indices in order by dimDiff to account for the
616 // new leading dimensions being added to the source shape
617 for (auto &o : order)
618 o += dimDiff;
619
620 // Add unit dimensions to the front of non-empty layout vectors and append
621 // the new dimension indices to the order array in reverse (slowest
622 // dimension has the lowest index and appears last in the order array)
623 for (int i = 0; i < dimDiff; i++) {
624 if (!sgLayout.empty())
625 sgLayout.insert(sgLayout.begin(), 1);
626 if (!sgData.empty())
627 sgData.insert(sgData.begin(), 1);
628 if (!instData.empty())
629 instData.insert(instData.begin(), 1);
630 if (!laneLayout.empty())
631 laneLayout.insert(laneLayout.begin(), 1);
632 if (!laneData.empty())
633 laneData.insert(laneData.begin(), 1);
634 order.push_back(dimDiff - 1 - i);
635 }
636
637 DenseI32ArrayAttr orderAttr = resLayout ? resLayout.getOrder() : nullptr;
638 auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
639 if (v.empty())
640 return DenseI32ArrayAttr();
641 SmallVector<int32_t> v32(v.begin(), v.end());
642 return DenseI32ArrayAttr::get(context, v32);
643 };
644 auto srcLayout = xegpu::LayoutAttr::get(
645 context, sgLayout.empty() ? nullptr : toAttr(sgLayout),
646 sgData.empty() ? nullptr : toAttr(sgData),
647 instData.empty() ? nullptr : toAttr(instData),
648 laneLayout.empty() ? nullptr : toAttr(laneLayout),
649 laneData.empty() ? nullptr : toAttr(laneData),
650 (!orderAttr || orderAttr.empty()) ? nullptr : toAttr(order));
651 return srcLayout;
652 }
653 return resLayout;
654}
655
656/// Infers the source layout attribute for a shape cast operation given the
657/// result layout attribute, result shape, and source shape.
658xegpu::DistributeLayoutAttr
659xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
660 ArrayRef<int64_t> resShape,
661 ArrayRef<int64_t> srcShape) {
662
663 // There are three use cases:
664 // 1. Expand unit dims of a lower-rank tensor (e.g., 1D to 2D): to set up
665 // the tensor before a broadcast.
666 // 2. Split a dim of a higher-rank tensor into several dims (e.g., 1D to
667 // 2D): to set up the tensor for a multi-stage reduction.
668 // 3. Combine all dims into a single dim placed innermost, either as
669 // [1, combinedData] or [combinedData]. Say, [2, 4, 8] -> [1, 64] or [64].
670 // These use cases are only supported after workgroup distribution, e.g., a
671 // cross-sg reduction that saves multi-dimensional data to a 1D SLM buffer,
672 // with the shape cast inserted by cse/canonicalization passes.
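 // For illustration (hypothetical shapes, not taken from a test):
 //   Case 1: srcShape = [64], resShape = [1, 64] -> the source layout is a
 //   slice of the result layout over the expanded unit dim 0.
 //   Case 2: srcShape = [128], resShape = [8, 16] (one split-dim group
 //   covering result dims 0 and 1) -> those result-layout dims are collapsed
 //   back into a single source dim.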
673
674 // Use case 1: Shapes only differ by expanding unit dimensions, for broadcast
675 SmallVector<int64_t> expandedUnitDims;
676
677 if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
678 // create a slice layout for the source by removing the expanded unit dims
679 auto sliceDimsAttr = DenseI64ArrayAttr::get(
680 resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
681 auto srcLayout =
682 xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
683 return srcLayout;
684 }
685
686 // Use case 2: Dim split from source to result, for multi-stage reduction
687 SmallVector<SmallVector<int64_t>> splitDimGroups;
688 if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups)) {
689 auto srcLayout = resLayout;
690 for (const auto &dimGroup : splitDimGroups)
691 srcLayout = srcLayout.collapseDims(dimGroup);
692
693 return srcLayout;
694 }
695
696 // Use case 3: Collapse to innermost dim, for cross-sg reduction to SLM
697 auto matchCollapseToInnermostDim = [&](ArrayRef<int64_t> src,
698 ArrayRef<int64_t> dst) -> bool {
699 // only one non-unit dim in dst which is the innermost dim
700 if ((dst.size() != 2) && (dst.size() != 1))
701 return false;
702 int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
703 std::multiplies<int64_t>());
704 if (dst.size() == 1)
705 return (dst[0] == srcSize);
706 return (dst[0] == 1) && (dst[1] == srcSize);
707 };
708
709 if (matchCollapseToInnermostDim(srcShape, resShape)) {
710 int srcShapeSize = srcShape.size();
711 int resShapeSize = resShape.size();
712 auto context = resLayout.getContext();
713 auto resInstData = resLayout.getEffectiveInstDataAsInt();
714 auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
715 auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
716
717 // Extract layout info from result's innermost dimension and apply to
718 // source's innermost dimension while setting all other dimensions to 1.
719 // The inferred layout is restricted by srcShape to ensure it fits within
720 // the source dimensions.
721 // Example 1:
722 // srcShape=[8, 16, 32], resShape=[1, 4096]
723 // resInstData=[1, 16]
724 // -> inferredInstData=[1, 1, min(16, 32)]=[1, 1, 16]
725 // Example 2:
726 // srcShape=[4, 8, 64], resShape=[2048]
727 // resLaneLayout=[16], resLaneData=[2]
728 // -> inferredLaneLayout=[1, 1, 16]
729 // -> inferredLaneData=[1, 1, min(2, 64/16)]=[1, 1, 2]
730
731 if (resInstData.size() != 0) {
732 // assert resInstData must be 1 for all but the innermost dim
733 for (int i = 0; i < resShapeSize - 1; i++) {
734 assert(resInstData[i] == 1 &&
735 "only innermost dim can have non-unit instData");
736 }
737 SmallVector<int> inferredInstData(srcShapeSize, 1);
738 inferredInstData[srcShapeSize - 1] =
739 std::min(resInstData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
740 return xegpu::LayoutAttr::get(context, inferredInstData);
741 }
742
743 if (resLaneLayout.size() != 0) {
744 for (int i = 0; i < resShapeSize - 1; i++) {
745 assert(resLaneData[i] == 1 &&
746 "only innermost dim can have non-unit laneData");
747 }
748 assert(srcShape.back() % resLaneLayout.back() == 0 &&
749 "source innermost dim must be divisible by the result lane layout");
750 SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
751 SmallVector<int> inferredLaneData(srcShapeSize, 1);
752 inferredLaneLayout.back() = resLaneLayout.back();
753 inferredLaneData.back() = std::min(
754 resLaneData.back(), srcShape.back() / inferredLaneLayout.back());
755 return xegpu::LayoutAttr::get(context, inferredLaneLayout,
756 inferredLaneData);
757 }
758 }
759 llvm_unreachable("running into unsupported shape cast scenarios");
760 return nullptr;
761}
762
763/// Infers the layout attribute for the mask and offset operands of chunked loads
764/// and stores, given the anchor layout attribute of the value being loaded/stored.
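/// For illustration (hypothetical values): a chunked access (chunkSize = 8)
/// whose payload layout has lane_layout = [16, 1] and lane_data = [1, 8]
/// drops the innermost dim, yielding a 1D mask/offset layout with
/// lane_layout = [16] and lane_data = [1]; non-chunked accesses reuse the
/// payload layout unchanged.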
765xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
766 xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
767 auto rank = payloadLayout.getRank();
768 if (chunkSize > 1)
769 return payloadLayout.dropDims(
770 llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
771 return payloadLayout;
772}
773
774/// Sets up layout for reduction operations by creating a SliceAttr for the
775/// result.
776///
777/// Algorithm Overview:
778/// This function attempts to construct a source layout that, when sliced along
779/// reduction dimensions, produces a result layout compatible with the
780/// consumer layout.
781///
782/// For subgroup layouts, it first tries to align the source layout's subgroup
783/// layout and data with the consumer's layout on non-reduction dimensions.
784/// Then, it distributes remaining subgroups across reduction dimensions. This
785/// avoids subgroup data redistribution overhead between the reduced result and
786/// its consumer. When the consumer layout is a slice layout, it attempts to
787/// reuse the slice layout's parent layout for the source to further minimize
788/// potential data redistribution.
789///
790/// InstData requires {1, ..., min(maxReduceVectorSize, srcShape), subgroupSize}
791/// Lane Layout requires {1, ..., 1, subgroupSize}
792/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
793///
794/// Examples:
795/// 1. Subgroup layout - Row reduction on 2D tensor:
796/// srcShape=[32, 128], reductionDims=[1], resShape=[32], subgroupSize=16,
797/// NumSg=32
798/// * Consumer Layout:
799/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>, dims =
800/// [1]>}
801/// * Result Layout:
802/// #xegpu.slice<#xegpu.layout<sg_layout=[4, 8],sg_data=[8, 16]>, dims =
803/// [1]>}
804/// Note that the sg_layout is reused but sg_data needs to be adjusted to
805/// evenly distribute the source tensor tile among the reduction dim.
806///
807/// 2. Subgroup layout - Same example above but consumer doesn't have a
808/// reusable slice layout.
809/// * Consumer Layout:
810/// #xegpu.layout<sgLayout=[32], sgData=[1]>
811/// * Result Layout:
812/// #xegpu.slice<#xegpu.layout<sgLayout=[32,1], sgData=[1, 64]>, dims =
813/// [1]>}
814/// * Consumer Layout:
815/// #xegpu.slice<#xegpu.layout<sgLayout=[8, 2, 4], sgData=[4, 64, 32]>,
816/// dims = [1, 2]>}
817/// * Result Layout:
818/// #xegpu.slice<#xegpu.layout<sgLayout=[8,4], sgData=[4, 32]>, dims =
819/// [1]>}
820/// Note that the consumer's layout can't be directly reused as is.
821/// So the algorithm distributes all subgroups on non reduction dimensions
822/// first and then distribute remaining subgroups on the reduction
823/// dimension.
824///
825/// 3. InstData layout - Column reduction:
826/// srcShape=[32, 64], reductionDims=[0], subgroupSize=16
827/// Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
828/// innermost)
829///
830/// 4. Lane layout - Multi-dimensional reduction:
831/// srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
832/// Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
833/// (subgroupSize on innermost dim, max vector size on reduction dim)
834
836 xegpu::LayoutKind layoutKind, VectorType srcVecTy,
837 DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
838 int numSg, const xegpu::uArch::uArch *uArch) {
839
840 auto srcShape = srcVecTy.getShape();
841 int srcRank = srcShape.size();
842 auto context = srcVecTy.getContext();
843
844 // Helper lambda to convert int64 vectors to int32 DenseArrayAttr
845 auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
846 SmallVector<int32_t> vec32(vec.begin(), vec.end());
847 return DenseI32ArrayAttr::get(context, vec32);
848 };
849
850 const int subgroupSize = uArch->getSubgroupSize();
851 int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
852 xegpu::DistributeLayoutAttr srcLayout;
853 if (layoutKind == xegpu::LayoutKind::Subgroup) {
854 xegpu::SliceAttr consumerSliceLayout =
855 dyn_cast_if_present<xegpu::SliceAttr>(consumerLayout);
856 if (consumerSliceLayout &&
857 consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
858 srcLayout = consumerSliceLayout.getParent();
859 SmallVector<int64_t> sgLayoutFromConsumer =
860 srcLayout.getEffectiveSgLayoutAsInt();
861 auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
862 if (srcSgData)
863 for (int dim = 0; dim < srcRank; dim++) {
864 if (llvm::is_contained(reductionDims, dim))
865 srcLayout =
866 srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
867 }
868 } else {
869 SmallVector<int64_t> consumerSgLayout =
870 consumerLayout ? consumerLayout.getEffectiveSgLayoutAsInt()
871 : SmallVector<int64_t>();
872 SmallVector<int64_t> consumerSgData =
873 consumerLayout ? consumerLayout.getEffectiveSgDataAsInt()
874 : SmallVector<int64_t>();
875 SmallVector<int64_t> consumerOrder =
876 consumerLayout ? consumerLayout.getEffectiveOrderAsInt()
877 : SmallVector<int64_t>();
878 DenseI32ArrayAttr orderAttr =
879 consumerLayout ? consumerLayout.getOrder() : nullptr;
880 SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
881 int remainingSgCount =
882 consumerLayout ? consumerLayout.getNumSubgroups() : numSg;
883 int consumerIdx = 0;
884
885 // First pass: Match consumer's layout on non-reduction dimensions
886 for (int i = 0; i < srcRank; i++) {
887 if (!llvm::is_contained(reductionDims, i) &&
888 consumerIdx < static_cast<int>(consumerSgLayout.size())) {
889 sgLayout[i] = consumerSgLayout[consumerIdx];
890 sgData[i] = consumerSgData[consumerIdx];
891 remainingSgCount /= sgLayout[i];
892 order[i] = consumerOrder[consumerIdx];
893 consumerIdx++;
894 }
895 }
896
897 // Second pass: Distribute remaining subgroups across reduction dimensions
898 // The reduction-to-scalar case is handled only by this loop.
899 int64_t remainOrder = consumerSgLayout.size();
900 for (int i = 0; i < srcRank; i++) {
901 if (llvm::is_contained(reductionDims, i)) {
902 sgLayout[i] =
903 std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
904 assert((srcShape[i] % sgLayout[i] == 0) &&
905 "source shape not divisible by sg_layout");
906 sgData[i] = srcShape[i] / sgLayout[i];
907 remainingSgCount /= sgLayout[i];
908 order[i] = remainOrder++;
909 }
910 }
911
912 assert(remainingSgCount == 1 && "not all subgroups distributed");
913 srcLayout = xegpu::LayoutAttr::get(
914 context, toInt32Attr(sgLayout), toInt32Attr(sgData),
915 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
916 /*lane_data =*/nullptr, /*order =*/
917 (!orderAttr || orderAttr.empty()) ? nullptr : toInt32Attr(order));
918 }
919 } else if (layoutKind == xegpu::LayoutKind::InstData) {
920
921 SmallVector<int64_t> instData(srcRank, 1);
922 if (srcRank >= 2)
923 instData[srcRank - 2] =
924 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
925 instData[srcRank - 1] =
926 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
927 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
928 } else if (layoutKind == xegpu::LayoutKind::Lane) {
929
930 SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
931 laneLayout[srcRank - 1] =
932 std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
933 if (srcRank >= 2)
934 laneData[srcRank - 2] =
935 std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
936 srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
937 toInt32Attr(laneData));
938 }
939
940 return xegpu::SliceAttr::get(context, srcLayout,
941 DenseI64ArrayAttr::get(context, reductionDims));
942}
943
944/// Sets up layout for Reduction operations by creating a SliceAttr for the
945/// result.
946xegpu::SliceAttr
948 VectorType srcVecTy,
949 const xegpu::uArch::uArch *uArch) {
950
951 auto srcShape = srcVecTy.getShape();
952 auto context = srcVecTy.getContext();
953 auto subgroupSize = uArch->getSubgroupSize();
954 xegpu::LayoutAttr srcLayout;
955
956 if (layoutKind == xegpu::LayoutKind::Subgroup) {
957 assert(false && "subgroup layout assignment not supported for reduction (op "
958 "is not expected at this level).");
959 } else if (layoutKind == xegpu::LayoutKind::InstData) {
960 assert(false && "instData layout assignment not supported for reduction (op "
961 "is not expected at this level).");
962 } else if (layoutKind == xegpu::LayoutKind::Lane) {
963 SmallVector<int32_t> laneLayout(1), laneData(1);
964 laneLayout[0] = std::min(subgroupSize, static_cast<int32_t>(srcShape[0]));
965 laneData[0] = 1;
966 srcLayout = xegpu::LayoutAttr::get(
967 context, DenseI32ArrayAttr::get(context, laneLayout),
968 DenseI32ArrayAttr::get(context, laneData));
969 }
970
971 auto result = xegpu::SliceAttr::get(context, srcLayout,
972 DenseI64ArrayAttr::get(context, 0));
973 return result;
974}
975
976/// Sets up the result layout for a bitcast operation.
977/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
978/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
979/// result layout can be correctly divided back to the source layout during
980/// inference.
981///
982/// Examples:
983/// 1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
984/// Consumer layout: instData=[1, 16], subgroupSize=16
985/// Source shape: [8, 32]
986/// Result layout: instData=[1, 32] (16 * 2)
987/// The innermost dimension is multiplied by 2 to maintain consistency.
988///
989/// 2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
990/// Consumer instData=[1, 16], subgroupSize=16
991/// Source shape: [4, 128]
992/// adjust the instData from [1, 16] to [1, 16 * 4 = 64]
993///
994/// 3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
995/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
996/// No adjustment needed - returns consumer layout directly.
997///
998xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
999 xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
1000 DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
1001
1002 int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1003 int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1004
1005 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1006 SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
1007 SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
1008 SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
1009 assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
1010 "consumer layout rank must match source shape rank");
1011 size_t dim = srcShape.size() - 1;
1012 int64_t sgDataValue = -1;
1013 int64_t instDataValue = -1;
1014 int64_t laneDataValue = -1;
1015 const int subgroupSize = uArch->getSubgroupSize();
1016
1017 if (srcElemTyBitWidth > resElemTyBitWidth) {
1018 // When casting to a smaller bitwidth, multiply the result layout
1019 // accordingly to ensure it can be divided by the ratio back to the
1020 // source layout.
1021 int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
1022 int innermostDimLaneLayout = subgroupSize;
1023 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1024 sgDataValue = sgData[dim];
1025 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1026 instDataValue = instData[dim];
1027 // Adjust instDataValue so it still fits within an instruction after
1028 // dividing by bitWidthRatio
1029 while ((instDataValue <= srcShape[dim]) &&
1030 (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
1031 instDataValue *= 2;
1032 assert((srcShape[dim] % instDataValue) == 0 &&
1033 "srcShape, instData, and lanelayout for innermost must be 2^n !");
1034 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1035 laneDataValue = laneData[dim];
1036 while ((laneDataValue <= srcShape[dim]) &&
1037 (laneDataValue % bitWidthRatio != 0))
1038 laneDataValue *= 2;
1039 }
1040 // Now set only instData and laneData, preserving sgData
1041 xegpu::DistributeLayoutAttr resLayout;
1042 resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
1043 laneDataValue);
1044 return resLayout;
1045 }
1046 return consumerLayout;
1047}
1048
1049/// Sets up the result layout for an interleave operation to ensure the source
1050/// layout can be safely derived. Interleave doubles the innermost dimension,
1051/// so the result layout must ensure that laneData is a multiple
1052/// of 2, and instData must be divisible by innermostDimLaneLayout * 2.
1053///
1054/// Example:
1055/// Interleave: vector<128x256xf4> -> vector<128x512xf4>
1056/// Consumer layout: laneLayout=[1, 16], laneData=[1, 4], instData=[1, 64]
1057/// Result layout adjustment to ensure source can be safely inferred:
1058/// - laneData must be >= 2 and multiple of 2 (so source = laneData/2 is
1059/// valid)
1060/// - instData must be divisible by (16 * 2 = 32) (so source = instData/2 is
1061/// valid)
1062/// - Adjusted instData: ensure (instData % 32 == 0)
1063///
1064xegpu::DistributeLayoutAttr xegpu::setupInterleaveResultLayout(
1065 xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
1066 DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
1067
1068 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1069 SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
1070 SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
1071 SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
1072
1073 assert(consumerLayout.getRank() == static_cast<int64_t>(srcShape.size()) &&
1074 "consumer layout rank must match source shape rank");
1075 const size_t innerMostDim = srcShape.size() - 1;
1076 int64_t sgDataValue = -1;
1077 int64_t instDataValue = -1;
1078 int64_t laneDataValue = -1;
1079
1080 // Interleave doubles the innermost dimension (ratio = 2)
1081 constexpr int ratio = 2;
1082 int innermostDimLaneLayout = uArch->getSubgroupSize();
1083
1084 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1085 sgDataValue = sgData[innerMostDim];
1086 // Ensure sgDataValue is divisible by ratio so source sgData can be inferred
1087 while ((sgDataValue <= srcShape[innerMostDim]) &&
1088 (sgDataValue % ratio != 0))
1089 sgDataValue *= ratio;
1090 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1091 instDataValue = instData[innerMostDim];
1092 // Adjust instDataValue so it can be divided by (innermostDimLaneLayout *
1093 // ratio) when inferring the source layout
1094 while ((instDataValue <= srcShape[innerMostDim]) &&
1095 (instDataValue % (innermostDimLaneLayout * ratio) != 0))
1096 instDataValue *= ratio;
1097 assert((srcShape[innerMostDim] % instDataValue) == 0 &&
1098 "srcShape, instData, and laneLayout for innermost must be 2^n!");
1099 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1100 laneDataValue = laneData[innerMostDim];
1101 // Ensure laneDataValue is at least 2 and divisible by ratio
1102 // so that source laneData = laneDataValue/2 is valid
1103 while ((laneDataValue <= srcShape[innerMostDim]) &&
1104 (laneDataValue % ratio != 0))
1105 laneDataValue *= ratio;
1106 }
1107
1108 return consumerLayout.setDimData(innerMostDim, sgDataValue, instDataValue,
1109 laneDataValue);
1110}
1111
1112/// Sets up the result layout for an insert strided slice operation.
1113/// Creates a result layout based on the specified layout kind (InstData or
1114/// Lane).
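/// For illustration (hypothetical values): with srcShape = [8, 16] and a
/// consumer inst_data of [8, 32], the InstData result layout is clamped per
/// dimension to [8, 16]; for the Lane kind, lane_data is likewise clamped to
/// min(srcShape[dim] / laneLayout[dim], consumer lane_data[dim]).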
1115xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
1116 xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
1117 VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
1118 const xegpu::uArch::uArch *uArch) {
1119
1120 xegpu::DistributeLayoutAttr requiredResLayout;
1121 SmallVector<int64_t> consumerInstData =
1122 consumerLayout.getEffectiveInstDataAsInt();
1123 SmallVector<int64_t> consumerLaneData =
1124 consumerLayout.getEffectiveLaneDataAsInt();
1125 SmallVector<int64_t> consumerLaneLayout =
1126 consumerLayout.getEffectiveLaneLayoutAsInt();
1127 ArrayRef<int64_t> srcShape = srcVectorTy.getShape();
1128 int64_t instDataValue = -1;
1129 int64_t laneDataValue = -1;
1130
1131 requiredResLayout = consumerLayout;
1132 int srcRank = srcShape.size();
1133
1134 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1135 assert(false &&
1136 "subgroup layout assignment not supported for insertStridedSlice.");
1137 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1138 for (int dim = 0; dim < srcRank; dim++) {
1139 instDataValue = std::min(srcShape[dim], consumerInstData[dim]);
1140 requiredResLayout =
1141 requiredResLayout.setDimData(dim, -1, instDataValue, -1);
1142 }
1143 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1144 for (int dim = 0; dim < srcRank; dim++) {
1145 assert(srcShape[dim] % consumerLaneLayout[dim] == 0 &&
1146 "srcShape must be divisible by laneLayout for all dimensions");
1147 laneDataValue = std::min(srcShape[dim] / consumerLaneLayout[dim],
1148 consumerLaneData[dim]);
1149 requiredResLayout =
1150 requiredResLayout.setDimData(dim, -1, -1, laneDataValue);
1151 }
1152 }
1153 return requiredResLayout;
1154}
1155
1156/// Sets up the anchor layout for load gather and load matrix operation.
1157/// load matrix lowers to load gather and 1d block load. All of them share the
1158/// same layout setup logic.
1159/// For Subgroup layout, uses the consumer layout directly.
1160/// non-chunked loads:
1161/// InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
1162/// LaneLayout = {1, ..., subgroupSize}
1163/// lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
1164/// chunked loads:
1165/// InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
1166/// LaneLayout = {subgroupSize, 1}
1167/// lane_data={1,min(consumer, maxLaneLoadSize)}
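/// For illustration (hypothetical values): a non-chunked load of
/// vector<256xf32> with subgroupSize = 16, maxLaneLoadSize = 4, and a
/// consumer lane_data of [4] gives lane_data = [4] and
/// lane_layout = [min(16, 256 / 4)] = [16]; the InstData variant would give
/// inst_data = [min(consumer, 4 * 16)].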
1168static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
1169 xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
1170 xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
1171 int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
1172
1173 if (layoutKind == xegpu::LayoutKind::Subgroup)
1174 return consumerLayout;
1175
1176 SmallVector<int64_t> consumerInstData =
1177 consumerLayout.getEffectiveInstDataAsInt();
1178 SmallVector<int64_t> consumerLaneData =
1179 consumerLayout.getEffectiveLaneDataAsInt();
1180
1181 SmallVector<int> instData(resShape.size(), 1);
1182 SmallVector<int> laneLayout(resShape.size(), 1);
1183 SmallVector<int> laneData(resShape.size(), 1);
1184
1185 if (!isChunkedLoad) {
1186 if (layoutKind == xegpu::LayoutKind::InstData) {
1187 instData.back() = std::min(static_cast<int>(consumerInstData.back()),
1188 maxChunkSize * subgroupSize);
1189 return xegpu::LayoutAttr::get(context, instData);
1190 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1191 laneData.back() =
1192 std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
1193 laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
1194 resShape.back() / laneData.back());
1195 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1196 }
1197 } else {
1198 assert(resShape.size() == 2 && "Chunked load must access a 2D tensor tile.");
1199 if (layoutKind == xegpu::LayoutKind::InstData) {
1200 instData[0] = subgroupSize;
1201 instData[1] =
1202 std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
1203 return xegpu::LayoutAttr::get(context, instData);
1204 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1205 laneLayout[0] = subgroupSize;
1206 laneData[1] =
1207 std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
1208 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1209 }
1210 }
1211 return nullptr;
1212}
1213
1214/// Sets up the anchor layout for a load gather operation.
1215xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
1216 xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
1217 xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
1218
1219 const int subgroupSize = uArch->getSubgroupSize();
1220 ArrayRef<int64_t> resShape = resVecTy.getShape();
1221 auto context = resVecTy.getContext();
1222 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1223
1224 const auto *uArchInstruction =
1225 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1227 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1228
1229 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
1230 (chunkSize > 1), maxChunkSize, resShape,
1231 subgroupSize);
1232}
1233
1234/// Sets up the anchor layout for load matrix operation.
1235/// TODO: enhance load matrix to indicate lowering to chunked load or not.
1236xegpu::DistributeLayoutAttr
1238 VectorType resVecTy,
1239 xegpu::DistributeLayoutAttr consumerLayout,
1240 const xegpu::uArch::uArch *uArch) {
1241
1242 const int subgroupSize = uArch->getSubgroupSize();
1243 ArrayRef<int64_t> resShape = resVecTy.getShape();
1244 auto context = resVecTy.getContext();
1245 auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
1246
1247 const auto *uArchInstruction =
1248 dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
1250 int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
1251 return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
1252 false, maxChunkSize, resShape,
1253 subgroupSize);
1254}
1255
1256/// Sets up the anchor layout for store scatter and store matrix operation.
1257/// store matrix lowers to store scatter and 1d block store. All of them share
1258/// the same layout setup logic. Subgroup layout is not supported yet.
1259/// non-chunked stores:
1260/// InstData = {1, ..., subgroupSize}
1261/// LaneLayout = {1, ..., subgroupSize}
1262/// lane_data = {1, ..., 1}
1263/// chunked stores:
1264/// InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
1265/// LaneLayout = {subgroupSize, 1}
1266/// lane_data={1,min(srcVec, maxLaneStoreSize)}
1267static xegpu::DistributeLayoutAttr
1268setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
1269 mlir::MLIRContext *context, bool isChunkedStore,
1270 int maxChunkSize, ArrayRef<int64_t> srcShape,
1271 int subgroupSize) {
1272
1273 int srcShapeSize = srcShape.size();
1274 SmallVector<int> instData(srcShapeSize, 1);
1275 SmallVector<int> laneLayout(srcShapeSize, 1);
1276 SmallVector<int> laneData(srcShapeSize, 1);
1277
1278 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1279 assert(false &&
1280 "subgroup layout assignment not supported for storeScatter.");
1281 return nullptr;
1282 }
1283
1284 if (!isChunkedStore) {
1285 if (layoutKind == xegpu::LayoutKind::InstData) {
1286 instData[srcShapeSize - 1] =
1287 std::min(subgroupSize, static_cast<int>(srcShape.back()));
1288 return xegpu::LayoutAttr::get(context, instData);
1289 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1290 laneLayout[srcShapeSize - 1] =
1291 std::min(subgroupSize, static_cast<int>(srcShape.back()));
1292 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1293 }
1294 } else {
1295 assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
1296 if (layoutKind == xegpu::LayoutKind::InstData) {
1297 instData[0] = subgroupSize;
1298 instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
1299 return xegpu::LayoutAttr::get(context, instData);
1300 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1301 laneLayout[0] = subgroupSize;
1302 laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
1303 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1304 }
1305 }
1306 return nullptr;
1307}
1308
1309/// Sets up the anchor layout for a store scatter operation.
1310xegpu::DistributeLayoutAttr
1312 VectorType srcVecTy, int chunkSize,
1313 const uArch::uArch *uArch) {
1314
1315 const int subgroupSize = uArch->getSubgroupSize();
1316 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1317 auto context = srcVecTy.getContext();
1318 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1319
1320 const auto *uArchInstruction =
1321 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1323 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1324 return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
1325 maxChunkSize, srcShape, subgroupSize);
1326}
1327
1328/// Sets up the anchor layout for a store matrix operation.
1329xegpu::DistributeLayoutAttr
1331 VectorType srcVecTy,
1332 const xegpu::uArch::uArch *uArch) {
1333
1334 const int subgroupSize = uArch->getSubgroupSize();
1335 ArrayRef<int64_t> srcShape = srcVecTy.getShape();
1336 auto context = srcVecTy.getContext();
1337 auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
1338
1339 const auto *uArchInstruction =
1340 dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
1342 int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
1343
1344 return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
1345 srcShape, subgroupSize);
1346}
1347
1348// This function returns the default lane layout for a given vector type.
1349// - `packingSize` means multiple consecutive elements can be accessed
1350// together as a single unit.
1351// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
1352// 1x2xf16 w/o vnni).
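// For illustration (hypothetical parameters): a vector<16x16xf16> with
// subgroupSize = 16 and packingSize = 32 yields lane_layout = [1, 16] and
// lane_data = [1, 2] without vnni, or lane_data = [2, 1] with vnni.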
1353template <typename RankedTy>
1354static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(
1355 RankedTy ty, const xegpu::uArch::uArch *uArch,
1356 std::optional<unsigned> packingSize = std::nullopt, bool vnni = false) {
1357 // Expecting a 1D or 2D vector.
1358 assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
1359 "Expected 1D non-vnni or 2D vector.");
1360 // Expecting int or float element type.
1361 assert(ty.getElementType().isIntOrFloat() &&
1362 "Expected int or float element type.");
1363
1364 auto context = ty.getContext();
1365 auto rank = ty.getRank();
1366 SmallVector<int> laneLayout(rank, 1);
1367 SmallVector<int> laneData(rank, 1);
1368 if (packingSize.has_value()) {
1369 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1370 int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
1371 laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
1372 }
1373 laneLayout.back() = uArch->getSubgroupSize();
1374 return xegpu::LayoutAttr::get(context, laneLayout, laneData);
1375}
1376
1377// This function returns all layouts for the given sgCount, whose sgData:
1378// 1. Evenly divides the wgShape.
1379// 2. Is a multiple of instData.
1380// Example:
1381// wgShape = [128, 64], instData = [8, 16], sgCount = 32
1382// Returns layouts:
1383// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
1384using LayoutRepresentation = std::pair<int64_t, int64_t>;
1387 int64_t sgCount) {
1388 SmallVector<LayoutRepresentation> candidates;
1389 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
1390 if (sgCount % sgLayout0)
1391 continue;
1392 int64_t sgLayout1 = sgCount / sgLayout0;
1393 int64_t sgData0 = wgShape[0] / sgLayout0;
1394 int64_t sgData1 = wgShape[1] / sgLayout1;
1395 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
1396 (sgData0 % instData[0] || sgData1 % instData[1]))
1397 continue;
1398 candidates.emplace_back(sgLayout0, sgLayout1);
1399 }
1400 // Sort primarily by how balanced they are
1401 // (i.e., minimize the absolute difference between the two dimensions), and
1402 // secondarily by the first dimension in ascending order.
1403 llvm::sort(candidates, [](const LayoutRepresentation &lhs,
1404 const LayoutRepresentation &rhs) {
1405 int diffLhs = std::abs(lhs.first - lhs.second);
1406 int diffRhs = std::abs(rhs.first - rhs.second);
1407 if (diffLhs != diffRhs)
1408 return diffLhs < diffRhs;
1409 return lhs.first < rhs.first;
1410 });
1411 return candidates;
1412}
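A self-contained sketch of the same candidate search using only the standard library, run on the example from the comment above (wgShape = [128, 64], instData = [8, 16], sgCount = 32); the identifiers below are illustrative and not part of this file:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>

int main() {
  const int64_t wg[2] = {128, 64}, inst[2] = {8, 16}, sgCount = 32;
  std::vector<std::pair<int64_t, int64_t>> candidates;
  for (int64_t l0 = 1; l0 <= sgCount; ++l0) {
    if (sgCount % l0)
      continue;
    int64_t l1 = sgCount / l0;
    // Keep only layouts that evenly divide the wg shape and whose sg data is a
    // multiple of the inst data.
    if (wg[0] % l0 || wg[1] % l1)
      continue;
    if ((wg[0] / l0) % inst[0] || (wg[1] / l1) % inst[1])
      continue;
    candidates.emplace_back(l0, l1);
  }
  // Most balanced first, then ascending first dimension.
  std::sort(candidates.begin(), candidates.end(), [](auto a, auto b) {
    int64_t da = std::llabs(a.first - a.second);
    int64_t db = std::llabs(b.first - b.second);
    return da != db ? da < db : a.first < b.first;
  });
  for (auto [l0, l1] : candidates) // prints (8,4) then (16,2)
    std::printf("(%lld,%lld)\n", (long long)l0, (long long)l1);
}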
1413
1414/// Helper function to compute inst_data vectors for DPAS operands A, B, and
1415/// C/D.
1416static std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
1417 SmallVector<int64_t>>>
1418getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
1419 const xegpu::uArch::uArch *uArch,
1420 bool isDpasMx = false) {
1421 const int subgroupSize = uArch->getSubgroupSize();
1422
1423 const xegpu::uArch::MMAInstructionInterface *uArchInstruction;
1424 if (isDpasMx)
1425 uArchInstruction = dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1428 else
1429 uArchInstruction =
1430 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1432
1433 const unsigned dataALen = aTy.getShape().front();
1434 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
1435 const int maxALen =
1436 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
1437
1438 const unsigned dataBLen = bTy.getShape().back();
1439 auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
1440 const int maxBLen =
1441 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
1442
1443 auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
1444 const int maxCLen =
1445 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
1446 if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
1447 return std::nullopt;
1448
1449 // For DPAS_MX, use getSupportedK to get the scaled K dimension.
1450 // Assume the returned vector contains a single element.
1451 int kDimSize = subgroupSize;
1452 if (isDpasMx) {
1453 auto supportedKLen = uArchInstruction->getSupportedK(aTy.getElementType());
1454 kDimSize = supportedKLen[0];
1455 }
1456
1457 SmallVector<int64_t> instDataA(aTy.getRank(), 1);
1458 instDataA[aTy.getRank() - 2] = maxALen;
1459 instDataA[aTy.getRank() - 1] = kDimSize;
1460 SmallVector<int64_t> instDataB(bTy.getRank(), 1);
1461 instDataB[bTy.getRank() - 2] = kDimSize;
1462 instDataB[bTy.getRank() - 1] = maxBLen;
1463 SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
1464 instDataCD[cdTy.getRank() - 2] = maxALen;
1465 instDataCD[cdTy.getRank() - 1] = maxCLen;
1466 return std::make_tuple(instDataA, instDataB, instDataCD);
1467}
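A worked illustration of the inst_data shapes this helper produces, under assumed hardware parameters (the subgroup size and supported M/N sets below are hypothetical, not read from any uArch):
//   aTy = vector<32x64xf16>, bTy = vector<64x32xf16>, cdTy = vector<32x32xf32>
//   subgroup size = 16, supported M = {1, 2, 4, 8}, supported N = {16}
//   -> maxALen = 8, maxBLen = 16, maxCLen = 16, kDimSize = 16 (regular DPAS)
//   -> instDataA = [8, 16], instDataB = [16, 16], instDataCD = [8, 16]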
1468
1469/// Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
1470/// Returns the three layouts if successful, nullopt otherwise.
1471static std::optional<
1472 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1473 xegpu::DistributeLayoutAttr>>
1474getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
1475 VectorType bTy, VectorType cdTy,
1476 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1477 const xegpu::uArch::uArch *uArch) {
1478 auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
1479 if (!instDataVecs)
1480 return std::nullopt;
1481 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1482 assert(instDataA.size() == 2 && instDataB.size() == 2 &&
1483 instDataCD.size() == 2 &&
1484 "Sg layout creation expects valid 2D inst data");
1485
1486 std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
1487 if (consumerLayout && consumerLayout.isForWorkgroup()) {
1488 SmallVector<int64_t> sgLayoutD = consumerLayout.getEffectiveSgLayoutAsInt();
1489 consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
1490 }
1491
1492 // Get all valid layouts for A, B and C/D operands
1493 auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
1494 auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
1495 auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
1496 if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
1497 return std::nullopt;
1498
1499 // Pick the best subgroup layout
1500 llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
1501 llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
1502 layoutsCD.end());
1503 std::optional<LayoutRepresentation> bestPick;
1504 auto checkAlignedSgDataAB = [&](LayoutRepresentation sgLayout) {
1505 return aTy.getShape().back() / sgLayout.second ==
1506 bTy.getShape().front() / sgLayout.first;
1507 };
1508 for (auto &sgLayout : layoutsB) {
1509 if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
1510 if (!checkAlignedSgDataAB(sgLayout))
1511 continue;
1512 // Is in (A and B and CD) and matches consumer -> best pick
1513 if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
1514 bestPick = sgLayout;
1515 break;
1516 }
1517 // Is in (A and B and CD). layoutsB is ordered from most balanced to
1518 // least, so the first candidate we see is the most balanced one;
1519 // remember it, and only replace it later if a candidate matches the
1520 // consumer layout.
1521 if (!bestPick)
1522 bestPick = sgLayout;
1523 }
1524 }
1525 if (!bestPick)
1526 return std::nullopt;
1527
1528 SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
1529 static_cast<int>(bestPick->second)};
1530 SmallVector<int> sgDataA = {static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
1531 static_cast<int>(aTy.getShape()[1])};
1532 SmallVector<int> sgDataB = {
1533 static_cast<int>(bTy.getShape()[0]),
1534 static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
1535 SmallVector<int> sgDataCD = {
1536 static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
1537 static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
1538
1539 auto dpasALayout =
1540 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1541 DenseI32ArrayAttr::get(context, sgDataA), nullptr,
1542 nullptr, nullptr, nullptr);
1543 auto dpasBLayout =
1544 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1545 DenseI32ArrayAttr::get(context, sgDataB), nullptr,
1546 nullptr, nullptr, nullptr);
1547 auto dpasCDLayout =
1548 xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
1549 DenseI32ArrayAttr::get(context, sgDataCD), nullptr,
1550 nullptr, nullptr, nullptr);
1551
1552 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
1553}
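A worked illustration of the tile shapes derived from a chosen subgroup layout, using assumed operand shapes (not taken from this file):
//   aTy = vector<128x64xf16>, bTy = vector<64x128xf16>, cdTy = vector<128x128xf32>,
//   numSg = 16, bestPick = (4, 4): it divides all three shapes, and
//   checkAlignedSgDataAB holds since 64/4 == 64/4.
//   sgDataA  = [128/4, 64]     = [32, 64]
//   sgDataB  = [64, 128/4]     = [64, 32]
//   sgDataCD = [128/4, 128/4]  = [32, 32]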
1554
1555/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
1556/// The numSg and consumerLayout (optional) are only used by sg layout
1557/// creation.
1558std::optional<
1559 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1560 xegpu::DistributeLayoutAttr>>
1561xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
1562 VectorType bTy, VectorType cdTy,
1563 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1564 const xegpu::uArch::uArch *uArch) {
1565 auto context = aTy.getContext();
1566 const auto *uArchInstruction =
1567 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1569
1570 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1571 assert(numSg > 0 &&
1572 "Number of subgroups must be provided for sg layout creation.");
1573 return getupDpasSubgroupLayouts(context, aTy, bTy, cdTy, consumerLayout,
1574 numSg, uArch);
1575 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1576 auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
1577 if (!instDataVecs)
1578 return std::nullopt;
1579 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1580 return std::make_tuple(
1581 xegpu::LayoutAttr::get(
1582 context, SmallVector<int>(instDataA.begin(), instDataA.end())),
1583 xegpu::LayoutAttr::get(
1584 context, SmallVector<int>(instDataB.begin(), instDataB.end())),
1585 xegpu::LayoutAttr::get(
1586 context, SmallVector<int>(instDataCD.begin(), instDataCD.end())));
1587 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1588 auto aLayout = getDefaultLaneLayout2DBlockIo(
1589 aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
1590 auto bLayout = getDefaultLaneLayout2DBlockIo(
1591 bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
1592 auto cdLayout = getDefaultLaneLayout2DBlockIo(
1593 cdTy, uArch /*, packingSize = std::nullopt */);
1594 return std::make_tuple(aLayout, bLayout, cdLayout);
1595 }
1596 return std::nullopt;
1597}
1598
1599/// Helper to create a scale layout derived from a matrix operand layout.
1600/// The scale layout is computed by mapping each dimension of the matrix layout
1601/// to the corresponding scale tensor dimension using the ratio between the
1602/// matrix and scale shapes.
1603static xegpu::DistributeLayoutAttr
1604createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy,
1605 VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout,
1606 bool isBScale, const xegpu::uArch::uArch *uArch) {
1607 if (!scaleTy || !matrixLayout)
1608 return nullptr;
1609
1610 // Calculate scaling factor by dividing matrix shape by scale shape
1611 ArrayRef<int64_t> matrixShape = matrixTy.getShape();
1612 ArrayRef<int64_t> scaleShape = scaleTy.getShape();
1613
1614 // Scale shapes can be 1D or 2D, handle both cases
1615 if (scaleShape.empty())
1616 return nullptr;
1617
1618 auto uArchInstruction =
1619 dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
1622
1623 int64_t rank = matrixLayout.getRank();
1624 assert(rank == 2 && "dpas layouts must be two dimensions");
1625
1626 SmallVector<int64_t> sgLayout = matrixLayout.getEffectiveSgLayoutAsInt();
1627 SmallVector<int64_t> sgData = matrixLayout.getEffectiveSgDataAsInt();
1628 SmallVector<int64_t> instData = matrixLayout.getEffectiveInstDataAsInt();
1629 SmallVector<int64_t> laneLayout = matrixLayout.getEffectiveLaneLayoutAsInt();
1630 SmallVector<int64_t> laneData = matrixLayout.getEffectiveLaneDataAsInt();
1631 auto order = matrixLayout.getOrder();
1632
1633 SmallVector<int> scaleSgLayout;
1634 SmallVector<int> scaleSgData;
1635 if (!sgLayout.empty() && !sgData.empty()) {
1636 scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
1637 scaleSgData.assign(sgData.begin(), sgData.end());
1638 scaleSgData[rank - 2] = std::max<int64_t>(
1639 scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
1640 scaleSgData[rank - 1] = std::max<int64_t>(
1641 scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
1642 }
1643
1644 // For DPAS_MX scales: if matrix has inst_data, scale needs adjusted
1645 // inst_data. Scale inst_data is derived from matrix inst_data divided by
1646 // scale factor.
1647 SmallVector<int> scaleInstData;
1648 if (!instData.empty()) {
1649 scaleInstData.assign(instData.begin(), instData.end());
1650 if (isBScale)
1651 scaleInstData[rank - 2] = std::max<int64_t>(
1652 scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
1653 1);
1654 else
1655 scaleInstData[rank - 1] = std::max<int64_t>(
1656 scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
1657 1);
1658 }
1659
1660 SmallVector<int> scaleLaneLayout;
1661 SmallVector<int> scaleLaneData;
1662 if (!laneLayout.empty() && !laneData.empty()) {
1663 scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
1664 scaleLaneData.assign(laneData.begin(), laneData.end());
1665 bool isRowMajor = uArchInstruction->isLaneLayoutRowMajorOrder();
1666 if (isBScale ^ isRowMajor) {
1667 std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
1668 scaleLaneLayout[rank - 2] =
1669 std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
1670 }
1671 scaleLaneData[rank - 2] =
1672 std::max<int64_t>(scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
1673 scaleLaneData[rank - 1] =
1674 std::max<int64_t>(scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
1675 }
1676 return xegpu::LayoutAttr::get(
1677 context,
1678 scaleSgLayout.empty() ? nullptr
1679 : DenseI32ArrayAttr::get(context, scaleSgLayout),
1680 scaleSgData.empty() ? nullptr
1681 : DenseI32ArrayAttr::get(context, scaleSgData),
1682 scaleInstData.empty() ? nullptr
1683 : DenseI32ArrayAttr::get(context, scaleInstData),
1684 scaleLaneLayout.empty()
1685 ? nullptr
1686 : DenseI32ArrayAttr::get(context, scaleLaneLayout),
1687 scaleLaneData.empty() ? nullptr
1688 : DenseI32ArrayAttr::get(context, scaleLaneData),
1689 order);
1690}
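A worked illustration of the scale sg_data derivation for the A operand, under assumed shapes (one scale element per 32 elements along K; these numbers are not from this file):
//   A matrix shape = [128, 64], A scale shape = [128, 2],
//   matrix layout: sg_layout = [4, 4], sg_data = [32, 16]
//   scaleSgData[0] = max(128 / (128 / 32), 1) = 32
//   scaleSgData[1] = max(  2 / ( 64 / 16), 1) =  1
//   -> scale layout: sg_layout = [4, 4], sg_data = [32, 1]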
1691
1692/// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
1693/// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
1694/// creation.
1695std::optional<
1696 std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1697 xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
1698 xegpu::DistributeLayoutAttr>>
1699xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
1700 VectorType bTy, VectorType cdTy, VectorType aScaleTy,
1701 VectorType bScaleTy,
1702 xegpu::DistributeLayoutAttr consumerLayout, int numSg,
1703 const xegpu::uArch::uArch *uArch) {
1704 auto context = aTy.getContext();
1705
1706 if (layoutKind == xegpu::LayoutKind::Subgroup) {
1707 assert(numSg > 0 &&
1708 "Number of subgroups must be provided for sg layout creation.");
1709 auto dpasLayouts = getupDpasSubgroupLayouts(context, aTy, bTy, cdTy,
1710 consumerLayout, numSg, uArch);
1711 if (!dpasLayouts)
1712 return std::nullopt;
1713
1714 auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
1715
1716 // Create scale layouts
1717 auto aScaleLayout =
1718 createScaleLayout(context, aTy, aScaleTy, dpasALayout, false, uArch);
1719
1720 auto bScaleLayout =
1721 createScaleLayout(context, bTy, bScaleTy, dpasBLayout, true, uArch);
1722
1723 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1724 bScaleLayout);
1725 } else if (layoutKind == xegpu::LayoutKind::InstData) {
1726 auto instDataVecs =
1727 getDpasInstDataVectors(aTy, bTy, cdTy, uArch, /*isDpasMx=*/true);
1728 if (!instDataVecs)
1729 return std::nullopt;
1730 auto [instDataA, instDataB, instDataCD] = *instDataVecs;
1731
1732 auto dpasALayout = xegpu::LayoutAttr::get(
1733 context, SmallVector<int>(instDataA.begin(), instDataA.end()));
1734 auto dpasBLayout = xegpu::LayoutAttr::get(
1735 context, SmallVector<int>(instDataB.begin(), instDataB.end()));
1736 auto dpasCDLayout = xegpu::LayoutAttr::get(
1737 context, SmallVector<int>(instDataCD.begin(), instDataCD.end()));
1738
1739 // Create scale layouts
1740 auto aScaleLayout =
1741 createScaleLayout(context, aTy, aScaleTy, dpasALayout, false, uArch);
1742 auto bScaleLayout =
1743 createScaleLayout(context, bTy, bScaleTy, dpasBLayout, true, uArch);
1744
1745 return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
1746 bScaleLayout);
1747 } else if (layoutKind == xegpu::LayoutKind::Lane) {
1748 const auto *uArchInstruction =
1749 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
1751 auto aLayout = getDefaultLaneLayout2DBlockIo(
1752 aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
1753 auto bLayout = getDefaultLaneLayout2DBlockIo(
1754 bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
1755 auto cdLayout = getDefaultLaneLayout2DBlockIo(cdTy, uArch);
1756
1757 // Create scale layouts
1758 auto aScaleLayout =
1759 createScaleLayout(context, aTy, aScaleTy, aLayout, false, uArch);
1760 auto bScaleLayout =
1761 createScaleLayout(context, bTy, bScaleTy, bLayout, true, uArch);
1762
1763 return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
1764 bScaleLayout);
1765 }
1766 return std::nullopt;
1767}
1768
1769xegpu::DistributeLayoutAttr
1770xegpu::inferSourceLayoutFromResult(OpOperand &operand,
1771 xegpu::DistributeLayoutAttr resLayout) {
1772 if (!resLayout)
1773 return nullptr;
1774 Operation *op = operand.getOwner();
1775 unsigned idx = operand.getOperandNumber();
1776
1777 // For vector::BroadcastOp, infer the source layout from the result layout.
1778 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
1779 auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
1780 if (!srcTy)
1781 return nullptr;
1782 return xegpu::inferBroadcastSourceLayout(
1783 resLayout, broadcast.getResultVectorType().getShape(),
1784 srcTy.getShape());
1785 }
1786
1787 // For vector::MultiDimReductionOp, infer source layout from result layout
1788 // using reduction dims. Acc operand is expected to have the same layout as
1789 // the result.
1790 if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
1791 if (idx == 0) {
1792 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
1793 return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
1794 }
1795 if (idx == 1)
1796 return resLayout;
1797 }
1798
1799 if (auto reduction = dyn_cast<vector::ReductionOp>(op))
1800 return xegpu::inferReductionSourceLayout(resLayout);
1801
1802 // For vector::BitCastOp, infer source layout from result layout using
1803 // element type bitwidths.
1804 if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1805 int resElemBitWidth =
1806 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1807 int srcElemBitWidth =
1808 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1809 return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
1810 srcElemBitWidth);
1811 }
1812
1813 // For vector::ShapeCastOp, infer source layout from result layout using
1814 // shapes.
1815 if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
1816 return xegpu::inferShapeCastSourceLayout(
1817 resLayout, shapeCast.getResultVectorType().getShape(),
1818 shapeCast.getSourceVectorType().getShape());
1819 }
1820
1821 // For vector::InsertStridedSliceOp, infer source layout from result layout.
1822 // Dest vector must have the same layout as the result.
1823 if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1824 if (idx == 0) {
1825 return xegpu::inferInsertStridedSliceSourceLayout(
1826 resLayout, insertSlice.getDestVectorType().getShape(),
1827 insertSlice.getSourceVectorType().getShape());
1828 }
1829 if (idx == 1)
1830 return resLayout;
1831 }
1832
1833 // For vector::InsertOp, infer source layout from result layout using
1834 // shapes.
1835 if (auto insert = dyn_cast<vector::InsertOp>(op)) {
1836 VectorType resVecTy = dyn_cast<VectorType>(insert.getResult().getType());
1837 VectorType valueToStoreTy =
1838 dyn_cast<VectorType>(insert.getValueToStore().getType());
1839
1840 if ((idx == 0) && valueToStoreTy) {
1841 return xegpu::inferInsertSourceLayout(resLayout, resVecTy.getShape(),
1842 valueToStoreTy.getShape());
1843 }
1844 if (idx == 1)
1845 return resLayout;
1846 }
1847
1848 // For vector::ExtractOp, infer source layout from result layout using
1849 // shapes.
1850 if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
1851 VectorType srcVecTy = dyn_cast<VectorType>(extract.getSource().getType());
1852 VectorType resVecTy = dyn_cast<VectorType>(extract.getResult().getType());
1853 if (!srcVecTy || !resVecTy)
1854 return nullptr;
1855 return xegpu::inferExtractSourceLayout(resLayout, resVecTy.getShape(),
1856 srcVecTy.getShape());
1857 }
1858
1859 // For vector::TransposeOp, infer source layout from result layout using
1860 // permutation.
1861 if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
1862 return xegpu::inferTransposeSourceLayout(resLayout,
1863 transpose.getPermutation());
1864 }
1865
1866 // For vector::BitCastOp, infer source layout from result layout using
1867 // element type bitwidths.
1868 if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
1869 int resElemBitWidth =
1870 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
1871 int srcElemBitWidth =
1872 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
1873 return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
1874 srcElemBitWidth);
1875 }
1876
1877 // For vector::InterleaveOp, infer source layout from result layout.
1878 if (auto interleave = dyn_cast<vector::InterleaveOp>(op)) {
1879 return xegpu::inferInterleaveSourceLayout(resLayout);
1880 }
1881
1882 // For vector::DeinterleaveOp, infer source layout from result layout.
1883 if (auto deinterleave = dyn_cast<vector::DeinterleaveOp>(op)) {
1884 return xegpu::inferDeinterleaveSourceLayout(resLayout);
1885 }
1886
1887 // For vector::ExtractStridedSliceOp, simply return the result layout.
1888 if (dyn_cast<vector::ExtractStridedSliceOp>(op))
1889 return resLayout;
1890 // For elementwise operations, all operands must have the same layout as the
1891 // result.
1892 if (OpTrait::hasElementwiseMappableTraits(op))
1893 return resLayout;
1894
1895 return nullptr;
1896}
1897
1898xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
1899 Operation *op = operand.getOwner();
1900 xegpu::DistributeLayoutAttr resLayout;
1901 if (op->getNumResults() == 1)
1902 resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
1903 auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
1904 if (inferredOperandLayout)
1905 return inferredOperandLayout;
1906 // By default, assume no layout conflict and return the current layout of
1907 // the operand.
1908 return xegpu::getDistributeLayoutAttr(operand.get());
1909}
std::pair< int64_t, int64_t > LayoutRepresentation
static xegpu::DistributeLayoutAttr createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy, VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout, bool isBScale, const xegpu::uArch::uArch *uArch)
Helper to create a scale layout derived from a matrix operand layout.
static std::optional< std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > > getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy, const xegpu::uArch::uArch *uArch, bool isDpasMx=false)
Helper function to compute inst_data vectors for DPAS operands A, B, and C/D.
static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result)
static xegpu::DistributeLayoutAttr setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, bool isChunkedStore, int maxChunkSize, ArrayRef< int64_t > srcShape, int subgroupSize)
Sets up the anchor layout for store scatter and store matrix operation.
static std::optional< std::tuple< xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr > > getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy, VectorType bTy, VectorType cdTy, xegpu::DistributeLayoutAttr consumerLayout, int numSg, const xegpu::uArch::uArch *uArch)
Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
static void propagateResultsToRegularOperands(Operation *op)
static void propagateRegionResultsToYieldOperands(mlir::RegionBranchTerminatorOpInterface yieldOp)
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp)
static void setTensorDescLayout(Value val, xegpu::DistributeLayoutAttr layout)
static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(RankedTy ty, const xegpu::uArch::uArch *uArch, std::optional< unsigned > packingSize=std::nullopt, bool vnni=false)
static void walkRegionBackward(Region &region, llvm::function_ref< void(Operation *)> visit)
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(xegpu::LayoutKind layoutKind, mlir::MLIRContext *context, xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad, int maxChunkSize, ArrayRef< int64_t > resShape, int subgroupSize)
Sets up the anchor layout for load gather and load matrix operation.
Block represents an ordered list of Operations.
Definition Block.h:33
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasAttrOfType(NameT &&name)
Definition Operation.h:601
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:512
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
Definition Operation.h:498
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:626
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool empty()
Definition Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int64_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert operation.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout)
Infers the source layout attribute for an operand using result layout attribute.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< NamedAttribute > dropSgLayoutAndDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping sg-layout and sg-data information from any Distribute...
DistributeLayoutAttr inferExtractSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an extract operation.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const =0
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const =0
virtual int getSubgroupSize() const =0
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:156
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:168