MLIR 23.0.0git
XeGPUPropagateLayout.cpp
Go to the documentation of this file.
1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/Visitors.h"
31#include "mlir/Support/LLVM.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51
52using namespace mlir;
53using namespace mlir::dataflow;
54
55namespace {
56
57//===----------------------------------------------------------------------===//
58// LayoutInfo
59//===----------------------------------------------------------------------===//
60
61/// Helper class for tracking the analysis state of an mlir value. For layout
62/// propagation, the analysis state is simply the distribution layout of
63/// each value. The distribution layout information is encapsulated using
64/// xegpu::DistributeLayoutAttr class which can hold information about any type
65/// of distribution layout that XeGPU dialect supports. Purpose of this analysis
66/// to propagate some unique distribution layout for each value in the program
67/// starting from a set of anchor operations (like DPAS, StoreNd, etc.). Note
68/// that analysis will reach a fixed point when all values are reached some
69/// layout and, analysis does not try to modify any already assigned layouts.
70///
71/// Given this, LayoutInfo satisifies the following properties:
72/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
73/// assigned`.
74/// 2) Two LayoutInfo values are equal if they are both assigned or
75/// both not assigned. The concrete value of assigned state does not matter.
76/// 3) The meet operator works as follows:
77/// - If current state is assigned, return the current state. (already
78/// a unique layout is assigned. don't change it)
79/// - Otherwise, return the other state.
80
81struct LayoutInfo {
82private:
83 xegpu::DistributeLayoutAttr storage = nullptr;
84
85public:
86 LayoutInfo() = default;
87 LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
88
89 // Two lattice values are equal if they have `some` layout. The actual
90 // content of the layout does not matter.
91 bool operator==(const LayoutInfo &other) const {
92 return this->isAssigned() == other.isAssigned();
93 }
94
95 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
96
97 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
98
99 void print(raw_ostream &os) const;
100
101 bool isAssigned() const { return storage != nullptr; }
102
103 LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
104
105 SmallVector<int> getLaneLayout() const;
106
107 SmallVector<int> getLaneData() const;
108
109 SmallVector<int> getInstData() const;
110
111 SmallVector<int> getSgLayout() const;
112
113 SmallVector<int> getSgData() const;
114
115 SmallVector<int> getOrder() const;
116
117 bool isSliceLayout() const {
118 if (!isAssigned())
119 return false;
120 return isa<xegpu::SliceAttr>(storage);
121 }
122
123 int64_t getRank() const {
124 if (!isAssigned())
125 return -1;
126 return storage.getRank();
127 }
128
129 Attribute get() { return storage; }
130 void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
131};
132
133SmallVector<int> LayoutInfo::getLaneLayout() const {
134 if (!isAssigned())
135 return {};
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](int64_t val) { return static_cast<int>(val); });
138}
139
140SmallVector<int> LayoutInfo::getLaneData() const {
141 if (!isAssigned())
142 return {};
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](int64_t val) { return static_cast<int>(val); });
145}
146
147SmallVector<int> LayoutInfo::getInstData() const {
148 if (!isAssigned())
149 return {};
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](int64_t val) { return static_cast<int>(val); });
152}
153
154SmallVector<int> LayoutInfo::getSgLayout() const {
155 if (!isAssigned())
156 return {};
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](int64_t val) { return static_cast<int>(val); });
159}
160
161SmallVector<int> LayoutInfo::getSgData() const {
162 if (!isAssigned())
163 return {};
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](int64_t val) { return static_cast<int>(val); });
166}
167
168SmallVector<int> LayoutInfo::getOrder() const {
169 if (!isAssigned() || !storage.getOrder())
170 return {};
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](int64_t val) { return static_cast<int>(val); });
173}
174
175void LayoutInfo::print(raw_ostream &os) const {
176 if (isAssigned()) {
177 os << storage;
178 } else {
179 os << "Not assigned.";
180 }
181}
182
183LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
184 if (!lhs.isAssigned())
185 return rhs;
186 return lhs;
187}
188
189/// Since this is a backward analysis, join method is not used.
190LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
191 llvm_unreachable("Join should not be triggered by layout propagation.");
192}
193
194/// Construct a new layout with the transposed inst_data or lane_layout,
195/// lane_data.
196LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
197 if (!isAssigned())
198 return {};
199 // Check if the permutation is valid.
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
204 });
205
206 if (!withinRange || hasDuplicates) {
207 assert(false && "Invalid permutation for transpose.");
208 return {};
209 }
210
211 SmallVector<int32_t> laneLayout;
212 SmallVector<int32_t> laneData;
213 SmallVector<int32_t> instData;
214 SmallVector<int32_t> sgLayout;
217
218 for (int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
221 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
222 }
223 if (getInstData().size())
224 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
227 sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
228 }
229 if (getOrder().size()) {
230 order.push_back(static_cast<int32_t>(getOrder()[idx]));
231 }
232 }
233 auto orderAttr = order.size()
234 ? DenseI32ArrayAttr::get(storage.getContext(), order)
235 : nullptr;
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
238 layoutAttr =
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
245 DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
246 DenseI32ArrayAttr::get(storage.getContext(), sgData),
247 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
248 /*lane_data =*/nullptr, orderAttr);
249 return LayoutInfo(layoutAttr);
250}
251
252//===----------------------------------------------------------------------===//
253// LayoutInfoLattice
254//===----------------------------------------------------------------------===//
255
256/// Lattice holding the LayoutInfo for each value.
/// Constructors are inherited from Lattice<LayoutInfo>.
257struct LayoutInfoLattice : public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
260};
261
262/// Helper Functions to get default layouts. A `default layout` is a layout that
263/// is assigned to a value when the layout is not fixed by some anchor operation
264/// (like DPAS).
265
266/// Helper Function to get the default layout for uniform values like constants.
267/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
268/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
269/// For ND vector (N>2), leading dims get unit lane_layout and lane_data.
270static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
271 unsigned rank,
272 const xegpu::uArch::uArch *uArch) {
273 assert(rank >= 1 && "Expected at least 1D vector.");
274 if (rank == 1) {
275 return LayoutInfo(
276 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
277 }
278 // For rank >= 2, lane_layout is [1, ..., 1, subgroupSize] and
279 // lane_data is [1, ..., 1, 1].
280 SmallVector<int32_t> laneLayout(rank, 1);
281 SmallVector<int32_t> laneData(rank, 1);
282 laneLayout[rank - 1] = uArch->getSubgroupSize();
283 return LayoutInfo(xegpu::LayoutAttr::get(ctx, laneLayout, laneData));
284}
285
286/// Helper to get the default layout for 2D block operations.
287/// For ND (N>2) types, leading dimensions get unit layout/data values.
288template <typename Ty>
289static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
291 unsigned packingSize) {
292 // Expecting at least 1D.
293 assert(ty.getRank() >= 1 && "Expected at least 1D vector.");
294 // Expecting int or float element type.
295 assert(ty.getElementType().isIntOrFloat() &&
296 "Expected int or float element type.");
297 // If the rank is 1, then return default layout for 1D vector.
298 if (ty.getRank() == 1)
299 return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
300 // Packing factor is determined by the element type bitwidth.
301 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
302 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
303 // For rank >= 2, distribute along the last dimension with leading units.
304 unsigned rank = ty.getRank();
305 SmallVector<int32_t> laneLayout(rank, 1);
306 SmallVector<int32_t> laneData(rank, 1);
307 laneLayout[rank - 1] = uArch->getSubgroupSize();
308 laneData[rank - 1] = packingFactor;
309 return LayoutInfo(
310 xegpu::LayoutAttr::get(ty.getContext(), laneLayout, laneData));
311}
312
313//===----------------------------------------------------------------------===//
314// LayoutInfoPropagation
315//===----------------------------------------------------------------------===//
316
317/// Backward data flow analysis that propagates a distribution layout to
318/// each value in the program. The kind of layout parameters propagated
319/// (subgroup, inst_data, or lane) is selected by `layoutKind`. Layouts of
320/// anchor operands (e.g. of DPAS, StoreNd, StoreScatter) are fixed before
321/// propagation and are pushed to all their producers and (other) consumers.
322class LayoutInfoPropagation
323 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
324public:
326
327private:
328 xegpu::LayoutKind layoutKind;
329 unsigned indexBitWidth;
330 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
332
333 void visitDpasMxOp(xegpu::DpasMxOp dpasMx,
336
337 void visitStoreNdOp(xegpu::StoreNdOp store,
340
341 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
344
345 void visitLoadNdOp(xegpu::LoadNdOp load,
348
349 void visitLoadGatherOp(xegpu::LoadGatherOp load,
352
353 void visitTransposeOp(vector::TransposeOp transpose,
356
357 void visitVectorBitcastOp(vector::BitCastOp bitcast,
360
361 void visitVectorInterleaveOp(vector::InterleaveOp interleave,
364
365 void visitVectorDeinterleaveOp(vector::DeinterleaveOp deinterleave,
368
369 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
372
373 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
376
377 void visitVectorReductionOp(vector::ReductionOp reduction,
380
381 void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
384 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
387 void
388 visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
391
392 void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
395
396 void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
399
// NOTE(review): the two overloads below take LoadMatrixOp / StoreMatrixOp but
// are named after gather/scatter, while visitOperation dispatches matrix ops
// to visitLoadMatrixOp / visitStoreMatrixOp. They look like stale
// declarations -- confirm whether they are defined/used and remove if not.
400 void visitLoadGatherOp(xegpu::LoadMatrixOp load,
403
404 void visitStoreScatterOp(xegpu::StoreMatrixOp store,
407
408 void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
411
/// Returns true if `anchorLayout` is non-null and carries the parameters of
/// the layout kind selected by `layoutKind`.
412 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
413
414public:
415 LayoutInfoPropagation(DataFlowSolver &solver,
416 SymbolTableCollection &symbolTable,
417 xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
418 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
419 layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
421
422 LogicalResult
423 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
424 ArrayRef<const LayoutInfoLattice *> results) override;
425
// Branch, call, non-control-flow-argument and external-call hooks get no
// special handling in this analysis; their bodies are intentionally empty.
426 void visitBranchOperand(OpOperand &operand) override {};
427
428 void visitCallOperand(OpOperand &operand) override {};
429
430 void
431 visitNonControlFlowArguments(RegionSuccessor &successor,
432 ArrayRef<BlockArgument> arguments) override {};
433
434 void visitExternalCall(CallOpInterface call,
436 ArrayRef<const LayoutInfoLattice *> results) override {
437 };
438
// Meets the lattice with an unassigned LayoutInfo. Given LayoutInfo::meet
// (keep an assigned lhs, otherwise take rhs), this leaves the value unchanged,
// assuming the lattice's meet delegates to LayoutInfo::meet.
439 void setToExitState(LayoutInfoLattice *lattice) override {
440 (void)lattice->meet(LayoutInfo());
441 }
442};
443} // namespace
444
445LogicalResult LayoutInfoPropagation::visitOperation(
446 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
447 ArrayRef<const LayoutInfoLattice *> results) {
449 .Case(
450 [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
451 .Case([&](xegpu::DpasMxOp dpasMxOp) {
452 visitDpasMxOp(dpasMxOp, operands, results);
453 })
454 .Case([&](xegpu::StoreNdOp storeNdOp) {
455 visitStoreNdOp(storeNdOp, operands, results);
456 })
457 .Case([&](xegpu::StoreScatterOp storeScatterOp) {
458 visitStoreScatterOp(storeScatterOp, operands, results);
459 })
460 .Case([&](xegpu::LoadNdOp loadNdOp) {
461 visitLoadNdOp(loadNdOp, operands, results);
462 })
463 .Case([&](xegpu::LoadGatherOp loadGatherOp) {
464 visitLoadGatherOp(loadGatherOp, operands, results);
465 })
466 .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
467 visitPrefetchNdOp(prefetchNdOp, operands, results);
468 })
469 .Case([&](vector::TransposeOp transposeOp) {
470 visitTransposeOp(transposeOp, operands, results);
471 })
472 .Case([&](vector::BitCastOp bitcastOp) {
473 visitVectorBitcastOp(bitcastOp, operands, results);
474 })
475 .Case([&](vector::InterleaveOp interleaveOp) {
476 visitVectorInterleaveOp(interleaveOp, operands, results);
477 })
478 .Case([&](vector::DeinterleaveOp deinterleaveOp) {
479 visitVectorDeinterleaveOp(deinterleaveOp, operands, results);
480 })
481 .Case([&](vector::MultiDimReductionOp reductionOp) {
482 visitVectorMultiReductionOp(reductionOp, operands, results);
483 })
484 .Case([&](vector::ReductionOp reductionOp) {
485 visitVectorReductionOp(reductionOp, operands, results);
486 })
487 .Case([&](vector::BroadcastOp broadcastOp) {
488 visitVectorBroadCastOp(broadcastOp, operands, results);
489 })
490 .Case([&](vector::ShapeCastOp shapeCastOp) {
491 visitShapeCastOp(shapeCastOp, operands, results);
492 })
493 .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
494 visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
495 })
496 .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
497 visitLoadMatrixOp(loadMatrixOp, operands, results);
498 })
499 .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
500 visitStoreMatrixOp(storeMatrixOp, operands, results);
501 })
502 .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
503 visitConvertLayoutOp(convertLayoutOp, operands, results);
504 })
505 // All other ops.
506 .Default([&](Operation *op) {
507 for (const LayoutInfoLattice *resultInfo : results) {
508 if (!resultInfo->getValue().isAssigned())
509 continue;
510 for (auto [operandInfo, operand] :
511 llvm::zip(operands, op->getOpOperands())) {
512 // If the operand type is not a vector or tensor descriptor, skip
513 // it.
514 if (!isa<xegpu::TensorDescType, VectorType>(
515 operand.get().getType()))
516 continue;
517 // Propagate the result layout to the operand.
518 meet(operandInfo, *resultInfo);
519 }
520 }
521 });
522
523 return success();
524}
525
526bool LayoutInfoPropagation::hasParamsOfLayoutKind(
527 xegpu::DistributeLayoutAttr anchorLayout) {
528 if (anchorLayout == nullptr) {
529 return false;
530 }
531 if (layoutKind == xegpu::LayoutKind::InstData) {
532 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
533 }
534 if (layoutKind == xegpu::LayoutKind::Lane) {
535 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
536 anchorLayout.getEffectiveLaneDataAsInt().empty());
537 }
538 if (layoutKind == xegpu::LayoutKind::Subgroup) {
539 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
540 anchorLayout.getEffectiveSgDataAsInt().empty());
541 }
542 return false;
543}
544
545// Returns all 2D subgroup layouts (sgLayout0, sgLayout1) with
// sgLayout0 * sgLayout1 == sgCount whose sgData (= wgShape / sgLayout):
546// 1. Evenly divides the wgShape.
547// 2. Is a multiple of instData.
// Candidates are sorted most-balanced first (smallest |sgLayout0 - sgLayout1|),
// with ties broken by the smaller sgLayout0. Only the first two entries of
// wgShape and instData are read, so 2D shapes are assumed.
// NOTE(review): a zero entry in instData would divide by zero below;
// presumably callers always pass positive inst_data -- confirm.
548// Example:
549// wgShape = [128, 64], instData = [8, 16], sgCount = 32
550// Returns layouts:
551// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
// (e.g. (4,8) is rejected: its sgData [32,8] is not a multiple of [8,16].)
553 ArrayRef<int> instData,
554 int64_t sgCount) {
556 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
557 if (sgCount % sgLayout0)
558 continue;
559 int sgLayout1 = sgCount / sgLayout0;
560 int sgData0 = wgShape[0] / sgLayout0;
561 int sgData1 = wgShape[1] / sgLayout1;
// Reject layouts that do not tile wgShape exactly, or whose per-subgroup
// tile is not a whole number of instruction tiles.
562 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
563 (sgData0 % instData[0] || sgData1 % instData[1]))
564 continue;
565 candidates.emplace_back(sgLayout0, sgLayout1);
566 }
567 // Sort primarily by how balanced they are
568 // (i.e., minimize the absolute difference between the two dimensions), and
569 // secondarily by the first dimension in ascending order.
570 llvm::sort(candidates, [](const std::pair<int, int> &lhs,
571 const std::pair<int, int> &rhs) {
572 int diffLhs = std::abs(lhs.first - lhs.second);
573 int diffRhs = std::abs(rhs.first - rhs.second);
574 if (diffLhs != diffRhs)
575 return diffLhs < diffRhs;
576 return lhs.first < rhs.first;
577 });
578 return candidates;
579}
580
581FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
582 // Oblivious to workitem layout, the total count matters.
583 auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
584 if (!gpuFunc)
585 return failure();
586 auto knownBlockSize = gpuFunc.getKnownBlockSize();
587 if (!knownBlockSize.has_value())
588 return failure();
589 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
590 return flatBlockSize / sgSize;
591}
592
593void LayoutInfoPropagation::visitPrefetchNdOp(
594 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
595 ArrayRef<const LayoutInfoLattice *> results) {
596
597 LayoutInfo prefetchLayout;
598 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
599 if (hasParamsOfLayoutKind(anchorLayout)) {
600 prefetchLayout = LayoutInfo(anchorLayout);
601 } else {
602 // Here we assign the default layout to the tensor descriptor operand of
603 // prefetch.
604 auto tdescTy = prefetch.getTensorDescType();
605
606 const uArch *uArch = getUArch(getChipStr(prefetch).value_or(""));
607 if (!uArch)
608 return;
609 const auto *uArchInstruction =
610 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
611 uArch->getInstruction(
612 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
613
614 auto blockWHC =
615 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
616 if (!blockWHC)
617 prefetch.emitWarning("No known block params found for the element type.");
618 auto [bWidth, bHeight, bCount] = blockWHC.value();
619 SmallVector<int> instData;
620 int instWidth = xegpu::getLargestDivisor(
621 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
622 if (instWidth == -1)
623 prefetch.emitWarning(
624 "No suitable instruction multiple found for the given shape.");
625 if (tdescTy.getRank() == 1)
626 instData = {instWidth};
627 else {
628 int instHeight = xegpu::getLargestDivisor(
629 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
630 if (instHeight == -1)
631 prefetch.emitWarning(
632 "No suitable instruction multiple found for the given shape.");
633 instData = {instHeight, instWidth};
634 }
635
636 if (layoutKind == xegpu::LayoutKind::InstData)
637 prefetchLayout =
638 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
639 else
640 prefetchLayout = getSIMTLayoutInfoBlockIO(
641 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
642
643 prefetch.setLayoutAttr(
644 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
645 }
646 // Propagate the layout to the source tensor descriptor.
647 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
648}
649
650void LayoutInfoPropagation::visitVectorMultiReductionOp(
651 vector::MultiDimReductionOp reduction,
652 ArrayRef<LayoutInfoLattice *> operands,
653 ArrayRef<const LayoutInfoLattice *> results) {
654 Type resultTy = reduction.getDestType();
655 // The layout of the result must be present.
656 LayoutInfo resLayoutInfo = results[0]->getValue();
657
658 xegpu::DistributeLayoutAttr consumerLayoutAttr;
659 if (!resultTy.isIntOrFloat()) {
660 if (!resLayoutInfo.isAssigned())
661 return;
662 consumerLayoutAttr =
663 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
664 }
665
666 VectorType sourceTy = reduction.getSourceVectorType();
667 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
668
669 const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
670 if (!uArch)
671 return;
672 int numSg = 0;
673 if (layoutKind == xegpu::LayoutKind::Subgroup) {
674 auto numSgOrErr = getNumSg(reduction, uArch->getSubgroupSize());
675 if (succeeded(numSgOrErr))
676 numSg = numSgOrErr.value();
677 }
678
679 // The result layout represents the layout requirements of the operation.
680 // it is recorded to anchor layout or temporary layout.
681 // it must be honored for current op and may conflict with the layout
682 // propagated from consumer op, the conflict is resolved in later phase by
683 // converting the required result layout to the consumer layout
684 auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
685 layoutKind, sourceTy, consumerLayoutAttr, reductionDims, numSg, uArch);
686
687 xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
688
689 // derive the source layout from the dominant layout and reduction dims
690 auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
691 requiredResLayoutAttr, reductionDims);
692
693 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
694 // Accumulator should have the same layout as the result.
695 propagateIfChanged(operands[1],
696 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
697}
698
699void LayoutInfoPropagation::visitVectorReductionOp(
700 vector::ReductionOp reduction, ArrayRef<LayoutInfoLattice *> operands,
701 ArrayRef<const LayoutInfoLattice *> results) {
702
703 VectorType sourceTy = reduction.getSourceVectorType();
704 const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
705 if (!uArch)
706 return;
707
708 auto requiredResLayoutAttr =
709 xegpu::setupReductionResultLayout(layoutKind, sourceTy, uArch);
710 xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
711
712 auto srcLayoutAttr = xegpu::inferReductionSourceLayout(requiredResLayoutAttr);
713 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
714 if (reduction.getAcc())
715 propagateIfChanged(operands[1],
716 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
717}
718
719void LayoutInfoPropagation::visitVectorBroadCastOp(
720 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
721 ArrayRef<const LayoutInfoLattice *> results) {
722 // The layout of the result must be present.
723 LayoutInfo resLayoutInfo = results[0]->getValue();
724 if (!resLayoutInfo.isAssigned())
725 return;
726
727 // Only consider vector to vector broadcasts for now.
728 VectorType resultTy = broadcast.getResultVectorType();
729 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
730 // skip layout propagation for non-vector source operand.
731 if (!sourceTy)
732 return;
733
734 auto srcShape = sourceTy.getShape();
735 auto resShape = resultTy.getShape();
736
737 auto resultLayoutAttr =
738 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
739
740 xegpu::DistributeLayoutAttr srcLayoutAttr =
741 xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
742
743 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
744}
745
746void LayoutInfoPropagation::visitShapeCastOp(
747 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
748 ArrayRef<const LayoutInfoLattice *> results) {
749 // The layout of the result must be present.
750 LayoutInfo resLayoutInfo = results[0]->getValue();
751 if (!resLayoutInfo.isAssigned())
752 return;
753 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
754 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
755 auto resultLayoutAttr =
756 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
757
758 xegpu::DistributeLayoutAttr srcLayoutAttr =
759 xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
760
761 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
762}
763
764/// Set the layouts for DPAS A, B, and C operands.
765void LayoutInfoPropagation::visitDpasOp(
766 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
767 ArrayRef<const LayoutInfoLattice *> results) {
768 LayoutInfo dpasALayout;
769 LayoutInfo dpasBLayout;
770 LayoutInfo dpasCDLayout;
771
772 xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
773 if (hasParamsOfLayoutKind(anchorLayoutCD)) {
774 xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
775 xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
776 assert(hasParamsOfLayoutKind(anchorLayoutA) &&
777 "Expected anchor layout for DPAS A operand.");
778 assert(hasParamsOfLayoutKind(anchorLayoutB) &&
779 "Expected anchor layout for DPAS B operand.");
780 dpasALayout = LayoutInfo(anchorLayoutA);
781 dpasBLayout = LayoutInfo(anchorLayoutB);
782 dpasCDLayout = LayoutInfo(anchorLayoutCD);
783 } else {
784 const uArch *uArch = getUArch(getChipStr(dpas).value_or(""));
785 if (!uArch)
786 return;
787 VectorType aTy = dpas.getLhsType();
788 VectorType bTy = dpas.getRhsType();
789 VectorType cdTy = dpas.getResultType();
790
791 xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
792 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
793 requiredBLayout;
794
795 int numSg = 0;
796 if (layoutKind == xegpu::LayoutKind::Subgroup) {
797 LayoutInfo consumerLayout = results[0]->getValue();
798 if (!consumerLayout.isAssigned())
799 return;
800 consumerLayoutAttr =
801 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
802 auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
803 if (failed(numSgOrErr)) {
804 dpas.emitWarning(
805 "Unable to determine the number of subgroups for the operation.");
806 return;
807 }
808 numSg = numSgOrErr.value();
809 }
810 auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
811 consumerLayoutAttr, numSg, uArch);
812 if (!layouts.has_value()) {
813 dpas.emitWarning(
814 "Failed to determine required layouts for DPAS operands.");
815 return;
816 }
817
818 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
819
820 dpas.setLayoutAAttr(requiredALayout);
821 dpas.setLayoutBAttr(requiredBLayout);
822 dpas.setLayoutCdAttr(requiredCDLayoutAttr);
823 dpasALayout = LayoutInfo(requiredALayout);
824 dpasBLayout = LayoutInfo(requiredBLayout);
825 dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
826 }
827 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
828 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
829 if (operands.size() > 2)
830 propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
831}
832
833/// Propagate layout for DpasMxOp operands using the layout attributes.
834/// DpasMxOp has operands: a, b, acc (optional), scale_a (optional), scale_b
835/// (optional)
836void LayoutInfoPropagation::visitDpasMxOp(
837 xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
838 ArrayRef<const LayoutInfoLattice *> results) {
839
840 // Initialize layout variables
841 LayoutInfo dpasMxALayout, dpasMxBLayout, dpasMxCDLayout;
842 LayoutInfo dpasMxAScaleLayout, dpasMxBScaleLayout;
843
844 // Get existing layout attributes from the operation
845 xegpu::DistributeLayoutAttr anchorLayoutA = dpasMx.getLayoutAAttr();
846 xegpu::DistributeLayoutAttr anchorLayoutB = dpasMx.getLayoutBAttr();
847 xegpu::DistributeLayoutAttr anchorLayoutCD = dpasMx.getLayoutCdAttr();
848
849 // Check if all layouts are already set
850 if (anchorLayoutA && anchorLayoutB && anchorLayoutCD &&
851 hasParamsOfLayoutKind(anchorLayoutA) &&
852 hasParamsOfLayoutKind(anchorLayoutB) &&
853 hasParamsOfLayoutKind(anchorLayoutCD)) {
854 dpasMxALayout = LayoutInfo(anchorLayoutA);
855 dpasMxBLayout = LayoutInfo(anchorLayoutB);
856 dpasMxCDLayout = LayoutInfo(anchorLayoutCD);
857
858 // Get scale layouts if available
859 xegpu::DistributeLayoutAttr anchorLayoutAScale =
860 dpasMx.getLayoutAScaleAttr();
861 xegpu::DistributeLayoutAttr anchorLayoutBScale =
862 dpasMx.getLayoutBScaleAttr();
863 if (anchorLayoutAScale)
864 dpasMxAScaleLayout = LayoutInfo(anchorLayoutAScale);
865 if (anchorLayoutBScale)
866 dpasMxBScaleLayout = LayoutInfo(anchorLayoutBScale);
867 } else {
868 // Need to compute layouts
869 const uArch *uArch = getUArch(getChipStr(dpasMx).value_or(""));
870 if (!uArch)
871 return;
872
873 VectorType aTy = dpasMx.getAType();
874 VectorType bTy = dpasMx.getBType();
875 VectorType cdTy = dpasMx.getResultType();
876
877 // Get scale types if present
878 VectorType aScaleTy;
879 VectorType bScaleTy;
880 Value scaleA = dpasMx.getScaleA();
881 Value scaleB = dpasMx.getScaleB();
882 if (scaleA)
883 aScaleTy = dyn_cast<VectorType>(scaleA.getType());
884 if (scaleB)
885 bScaleTy = dyn_cast<VectorType>(scaleB.getType());
886
887 xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
888 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
889 requiredBLayout, requiredAScaleLayout, requiredBScaleLayout;
890
891 int numSg = 0;
892 if (layoutKind == xegpu::LayoutKind::Subgroup) {
893 LayoutInfo consumerLayout = results[0]->getValue();
894 if (!consumerLayout.isAssigned())
895 return;
896 consumerLayoutAttr =
897 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
898 auto numSgOrErr = getNumSg(dpasMx, uArch->getSubgroupSize());
899 if (failed(numSgOrErr)) {
900 dpasMx.emitWarning(
901 "Unable to determine the number of subgroups for the operation.");
902 return;
903 }
904 numSg = numSgOrErr.value();
905 }
906
907 auto layouts =
908 xegpu::setupDpasMxLayout(layoutKind, aTy, bTy, cdTy, aScaleTy, bScaleTy,
909 consumerLayoutAttr, numSg, uArch);
910 if (!layouts.has_value()) {
911 dpasMx.emitWarning(
912 "Failed to determine required layouts for DPAS_MX operands.");
913 return;
914 }
915
916 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr,
917 requiredAScaleLayout, requiredBScaleLayout) = *layouts;
918
919 dpasMx.setLayoutAAttr(requiredALayout);
920 dpasMx.setLayoutBAttr(requiredBLayout);
921 dpasMx.setLayoutCdAttr(requiredCDLayoutAttr);
922 if (requiredAScaleLayout)
923 dpasMx.setLayoutAScaleAttr(requiredAScaleLayout);
924 if (requiredBScaleLayout)
925 dpasMx.setLayoutBScaleAttr(requiredBScaleLayout);
926
927 dpasMxALayout = LayoutInfo(requiredALayout);
928 dpasMxBLayout = LayoutInfo(requiredBLayout);
929 dpasMxCDLayout = LayoutInfo(requiredCDLayoutAttr);
930 if (requiredAScaleLayout)
931 dpasMxAScaleLayout = LayoutInfo(requiredAScaleLayout);
932 if (requiredBScaleLayout)
933 dpasMxBScaleLayout = LayoutInfo(requiredBScaleLayout);
934 }
935
936 // Propagate layouts to operands. Because acc, scale_a, scale_b are all
937 // optional (AttrSizedOperandSegments), the index of each present operand in
938 // `operands` depends on which optionals are actually supplied. Use the
939 // op's accessors to determine the correct positional index.
940 propagateIfChanged(operands[0], operands[0]->meet(dpasMxALayout));
941 propagateIfChanged(operands[1], operands[1]->meet(dpasMxBLayout));
942 unsigned idx = 2;
943 if (dpasMx.getAcc()) {
944 propagateIfChanged(operands[idx], operands[idx]->meet(dpasMxCDLayout));
945 ++idx;
946 }
947 if (dpasMx.getScaleA()) {
948 if (dpasMxAScaleLayout.isAssigned())
949 propagateIfChanged(operands[idx],
950 operands[idx]->meet(dpasMxAScaleLayout));
951 ++idx;
952 }
953 if (dpasMx.getScaleB()) {
954 if (dpasMxBScaleLayout.isAssigned())
955 propagateIfChanged(operands[idx],
956 operands[idx]->meet(dpasMxBScaleLayout));
957 ++idx;
958 }
959}
960
961/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
962void LayoutInfoPropagation::visitStoreNdOp(
963 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
964 ArrayRef<const LayoutInfoLattice *> results) {
965 LayoutInfo storeLayout;
966 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
967 if (hasParamsOfLayoutKind(anchorLayout)) {
968 storeLayout = LayoutInfo(anchorLayout);
969 } else {
970 const uArch *uArch = getUArch(getChipStr(store).value_or(""));
971 if (!uArch)
972 return;
973 const auto *uArchInstruction =
974 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
975 uArch->getInstruction(
976 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
977 VectorType dataTy = store.getValueType();
978 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
979 store.getValueType().getElementType());
980 if (!blockWHC)
981 store.emitWarning("No known block params found for the element type.");
982 auto [bWidth, bHeight, bCount] = blockWHC.value();
983 // Default to 1 for any leading batch dims; rank-1 and rank>=2 cases
984 // overwrite the trailing entries below.
985 SmallVector<int> instData(dataTy.getRank(), 1);
986 int instWidth = xegpu::getLargestDivisor(
987 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
988 if (instWidth == -1)
989 store.emitWarning(
990 "No suitable instruction multiple found for the given shape.");
991 if (dataTy.getRank() == 1) {
992 instData = {instWidth};
993 } else {
994 int instHeight = xegpu::getLargestDivisor(
995 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
996 if (instHeight == -1)
997 store.emitWarning(
998 "No suitable instruction multiple found for the given shape.");
999 instData[dataTy.getRank() - 2] = instHeight;
1000 instData[dataTy.getRank() - 1] = instWidth;
1001 }
1002
1003 if (layoutKind == xegpu::LayoutKind::InstData)
1004 storeLayout =
1005 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
1006 else if (layoutKind == xegpu::LayoutKind::Lane)
1007 storeLayout =
1008 getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
1009 uArchInstruction->getPackedFormatBitSize());
1010 else { // xegpu::LayoutKind::Subgroup
1011 auto sgSize = uArch->getSubgroupSize();
1012 auto numSgOrErr = getNumSg(store, sgSize);
1013 if (failed(numSgOrErr)) {
1014 store.emitWarning(
1015 "Unable to determine the number of subgroups for the operation.");
1016 return;
1017 }
1018 auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
1019 instData, numSgOrErr.value());
1020 if (sgLayouts.empty()) {
1021 store.emitWarning(
1022 "Unable to determine suitable subgroup layout for store value.");
1023 return;
1024 }
1025 SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
1026 SmallVector<int> sgData = {
1027 static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
1028 static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
1029 storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
1030 dataTy.getContext(),
1031 DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
1032 DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
1033 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
1034 /*lane_data =*/nullptr, /*order =*/nullptr));
1035 }
1036 store.setLayoutAttr(
1037 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
1038 }
1039 // Propagate the layout to the value operand.
1040 // Both operands should have the same layout
1041 for (LayoutInfoLattice *operand : operands)
1042 propagateIfChanged(operand, operand->meet(storeLayout));
1043}
1044
1045/// Propagate the layout of the value to the tensor descriptor operand in
1046/// LoadNdOp.
1047void LayoutInfoPropagation::visitLoadNdOp(
1048 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
1049 ArrayRef<const LayoutInfoLattice *> results) {
1050 LayoutInfo loadLayout;
1051 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
1052 if (hasParamsOfLayoutKind(anchorLayout)) {
1053 loadLayout = LayoutInfo(anchorLayout);
1054 } else {
1055
1056 LayoutInfo valueLayout = results[0]->getValue();
1057 // Need the layout of the value to propagate to the tensor descriptor.
1058 if (!valueLayout.isAssigned())
1059 return;
1060 loadLayout = valueLayout;
1061 // LoadNdOp has the transpose effect. However, at the stage of this analysis
1062 // this effect is not expected and should be abstracted away. Emit a
1063 // warning.
1064 if (auto transpose = load.getTranspose()) {
1065 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
1066 "LayoutInfoPropagation stage.");
1067 loadLayout = valueLayout.transpose(transpose.value());
1068 }
1069 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
1070 }
1071 // Propagate the new layout to the tensor descriptor operand.
1072 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
1073}
1074
1075/// Propagate the layout of the value to the tensor descriptor operand in
1076/// ConvertLayoutOp.
1077void LayoutInfoPropagation::visitConvertLayoutOp(
1078 xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
1079 ArrayRef<const LayoutInfoLattice *> results) {
1080 xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
1081 LayoutInfo convertLayout(anchorLayout);
1082 // Propagate the new layout to the tensor descriptor operand.
1083 propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
1084}
1085
1086/// For vector::TransposeOp, the layout of the result is transposed and
1087/// propagated to the operand.
1088void LayoutInfoPropagation::visitTransposeOp(
1089 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
1090 ArrayRef<const LayoutInfoLattice *> results) {
1091 // Need the layout of transpose result to propagate to the operands.
1092 LayoutInfo resultLayout = results[0]->getValue();
1093 if (!resultLayout.isAssigned())
1094 return;
1095
1096 auto consumerLayoutAttr =
1097 dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
1098 auto srcLayoutAttr = xegpu::inferTransposeSourceLayout(
1099 consumerLayoutAttr, transpose.getPermutation());
1100
1101 // Propagate the new layout to the vector operand.
1102 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1103}
1104
1105/// For vector::BitCastOp, the lane_data of the source layout is changed based
1106/// on the bit width of the source and result types.
1107void LayoutInfoPropagation::visitVectorBitcastOp(
1108 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
1109 ArrayRef<const LayoutInfoLattice *> results) {
1110 // Need the layout of bitcast result to propagate to the operands.
1111 LayoutInfo resLayoutInfo = results[0]->getValue();
1112 if (!resLayoutInfo.isAssigned())
1113 return;
1114
1115 auto srcVecType = bitcast.getSourceVectorType();
1116 auto resVecType = bitcast.getResultVectorType();
1117
1118 auto consumerLayoutAttr =
1119 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1120 const uArch *uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
1121 if (!uArch)
1122 return;
1123 auto requiredResLayoutAttr = setupBitCastResultLayout(
1124 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1125
1126 xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
1127
1128 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1129 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
1130
1131 // derive the source layout from the dominant layout and reduction dims
1132 auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
1133 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
1134
1135 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1136}
1137
1138/// For vector::InterleaveOp, the result has double the innermost dimension size
1139/// compared to each source operand. The layout is propagated from result to
1140/// sources, adjusting for the 2x size increase.
1141void LayoutInfoPropagation::visitVectorInterleaveOp(
1142 vector::InterleaveOp interleave, ArrayRef<LayoutInfoLattice *> operands,
1143 ArrayRef<const LayoutInfoLattice *> results) {
1144 // Need the layout of interleave result to propagate to the operands.
1145 LayoutInfo resLayoutInfo = results[0]->getValue();
1146 if (!resLayoutInfo.isAssigned())
1147 return;
1148
1149 auto srcVecType = interleave.getSourceVectorType();
1150 auto resVecType = interleave.getResultVectorType();
1151
1152 auto consumerLayoutAttr =
1153 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1154 const uArch *uArch = getUArch(xegpu::getChipStr(interleave).value_or(""));
1155 if (!uArch)
1156 return;
1157
1158 // Setup the result layout to ensure the source layout can be safely derived
1159 auto requiredResLayoutAttr = setupInterleaveResultLayout(
1160 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1161
1162 xegpu::setTemporaryLayout(interleave->getResult(0), requiredResLayoutAttr);
1163
1164 // Derive the source layout from the result layout (halve the innermost dim)
1165 auto srcLayoutAttr =
1166 xegpu::inferInterleaveSourceLayout(requiredResLayoutAttr);
1167
1168 // Both operands (lhs and rhs) get the same source layout
1169 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1170 propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(srcLayoutAttr)));
1171}
1172
1173/// For vector::DeinterleaveOp, the source has double the innermost dimension
1174/// size compared to each result. The layout is propagated from results to
1175/// source, adjusting for the 2x size decrease in results.
1176void LayoutInfoPropagation::visitVectorDeinterleaveOp(
1177 vector::DeinterleaveOp deinterleave, ArrayRef<LayoutInfoLattice *> operands,
1178 ArrayRef<const LayoutInfoLattice *> results) {
1179 // Need the layout of deinterleave results to propagate to the operand.
1180 // Use the first result's layout (both results should have the same layout)
1181 LayoutInfo resLayoutInfo = results[0]->getValue();
1182 if (!resLayoutInfo.isAssigned())
1183 return;
1184
1185 auto consumerLayoutAttr =
1186 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1187
1188 // Derive the source layout from the result layout (double the innermost dim)
1189 // No setup function needed - just infer directly
1190 auto srcLayoutAttr = xegpu::inferDeinterleaveSourceLayout(consumerLayoutAttr);
1191
1192 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1193}
1194
1195void LayoutInfoPropagation::visitInsertStridedSliceOp(
1196 vector::InsertStridedSliceOp insertStridedSlice,
1197 ArrayRef<LayoutInfoLattice *> operands,
1198 ArrayRef<const LayoutInfoLattice *> results) {
1199 // The layout of the result must be present.
1200 LayoutInfo resLayoutInfo = results[0]->getValue();
1201 if (!resLayoutInfo.isAssigned())
1202 return;
1203
1204 auto srcVecType = insertStridedSlice.getSourceVectorType();
1205 auto resVecType = insertStridedSlice.getDestVectorType();
1206
1207 auto consumerLayoutAttr =
1208 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1209 const uArch *uArch =
1210 getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
1211 if (!uArch)
1212 return;
1213
1214 auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
1215 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1216 xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
1217 requiredResLayoutAttr);
1218
1219 auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
1220 requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
1221 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1222 propagateIfChanged(operands[1],
1223 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
1224}
1225
1226/// Propagate the layout of the result to the tensor descriptor, mask and offset
1227/// operands in LoadGatherOp.
1228void LayoutInfoPropagation::visitLoadGatherOp(
1229 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
1230 ArrayRef<const LayoutInfoLattice *> results) {
1231 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1232 xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
1233 const uArch *uArch = getUArch(getChipStr(load).value_or(""));
1234 if (!uArch)
1235 return;
1236 VectorType resVecTy = load.getValueType();
1237 int chunkSize = load.getChunkSize().value_or(1);
1238
1239 LayoutInfo resLayoutInfo = results[0]->getValue();
1240 if (!resLayoutInfo.isAssigned())
1241 return;
1242 auto consumerLayoutAttr =
1243 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1244
1245 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1246 requiredAnchorLayoutAttr = anchorLayoutAttr;
1247 } else {
1248 if (!resVecTy) {
1249 load.emitWarning("Not propagating, non-vector payload supplied.");
1250 return;
1251 }
1252 requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
1253 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1254 load.setLayoutAttr(requiredAnchorLayoutAttr);
1255 }
1256
1257 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1258 auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
1259 requiredAnchorLayoutAttr, chunkSize);
1260 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1261 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1262
1263 // Propagate the new layout to the tensor descriptor operand.
1264 if (isa<xegpu::TensorDescType>(load.getSourceType()))
1265 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1266 // Propagate the new layout to the offset and mask operands.
1267 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1268 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1269}
1270
1271/// Set the layout for the value, tensor descriptor, offset and mask operands in
1272/// the StoreScatterOp.
1273void LayoutInfoPropagation::visitStoreScatterOp(
1274 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1275 ArrayRef<const LayoutInfoLattice *> results) {
1276
1277 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1278 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1279 const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
1280 if (!uArch)
1281 return;
1282 VectorType srcVecTy = storeScatter.getValueType();
1283 int chunkSize = storeScatter.getChunkSize().value_or(1);
1284
1285 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1286 requiredAnchorLayoutAttr = anchorLayoutAttr;
1287 } else {
1288 if (!srcVecTy) {
1289 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
1290 return;
1291 }
1292 requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
1293 layoutKind, srcVecTy, chunkSize, uArch);
1294 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1295 }
1296
1297 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1298 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1299 auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
1300 requiredAnchorLayoutAttr, chunkSize);
1301 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1302
1303 // Propagate the payload operand layout
1304 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1305 // Propagate the destination (if tdesc) operand layout
1306 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1307 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1308 // Propagate the new layout to the offset and mask operands.
1309 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1310 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
1311}
1312
1313void LayoutInfoPropagation::visitLoadMatrixOp(
1314 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1315 ArrayRef<const LayoutInfoLattice *> results) {
1316
1317 LayoutInfo resLayoutInfo = results[0]->getValue();
1318 if (!resLayoutInfo.isAssigned())
1319 return;
1320
1321 auto consumerLayoutAttr =
1322 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1323
1324 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1325
1326 // only need to set anchor layout, no need to porpagate to memdesc and
1327 // offset
1328 if (!hasParamsOfLayoutKind(anchorLayout)) {
1329 VectorType resVecTy =
1330 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1331 const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
1332 if (!uArch)
1333 return;
1334 auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
1335 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1336 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
1337 }
1338}
1339
1340void LayoutInfoPropagation::visitStoreMatrixOp(
1341 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1342 ArrayRef<const LayoutInfoLattice *> results) {
1343 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1344 LayoutInfo layout;
1345 if (hasParamsOfLayoutKind(anchorLayout)) {
1346 layout = LayoutInfo(anchorLayout);
1347 } else {
1348 VectorType srcVecTy =
1349 llvm::cast<VectorType>(storeMatrix.getData().getType());
1350 const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
1351 if (!uArch)
1352 return;
1353 auto requiredAnchorLayoutAttr =
1354 xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
1355 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1356 layout = LayoutInfo(requiredAnchorLayoutAttr);
1357 }
1358
1359 propagateIfChanged(operands[0], operands[0]->meet(layout));
1360}
1361
1362namespace {
1363//===----------------------------------------------------------------------===//
1364// RunLayoutInfoPropagation
1365//===----------------------------------------------------------------------===//
1366
1367/// Driver class for running the LayoutInfoPropagation analysis.
1368class RunLayoutInfoPropagation {
1369public:
1370 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
1371
1372 RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
1373 unsigned indexBitWidth)
1374 : target(op) {
1375 SymbolTableCollection symbolTable;
1376 loadBaselineAnalyses(solver);
1377 solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
1378 (void)solver.initializeAndRun(op);
1379 }
1380
1381 LayoutInfo getLayoutInfo(Value val);
1382
1383 void printAnalysisResult(llvm::raw_ostream &os);
1384
1385private:
1386 DataFlowSolver solver;
1387 const Operation *target;
1388};
1389} // namespace
1390
1391LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1392 auto *state = solver.lookupState<LayoutInfoLattice>(val);
1393 if (!state)
1394 return {};
1395 return state->getValue();
1396}
1397
1398// Print the analysis result for debugging purposes.
1399void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1400 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1401 os << "function: " << funcOp.getName() << ":\n";
1402 // Function arguments
1403 for (BlockArgument arg : funcOp.getArguments()) {
1404 LayoutInfo layout = getLayoutInfo(arg);
1405 os << "argument: " << arg << "\n";
1406 os << "layout : ";
1407 layout.print(os);
1408 os << "\n";
1409 }
1410 // Function ops
1411 funcOp.walk([&](Operation *op) {
1412 // Skip ops that do not have results
1413 if (op->getResults().empty())
1414 return;
1415 os << "op : ";
1416 // For control-flow ops, print the op name only.
1417 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1418 os << op->getName();
1419 else
1420 op->print(os);
1421 os << "\n";
1422 // Print the layout for each result.
1423 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1424 LayoutInfo layout = getLayoutInfo(r);
1425 os << "layout for result #" << i << ": ";
1426 layout.print(os);
1427 os << "\n";
1428 }
1429 });
1430 };
1431
1432 SmallVector<FunctionOpInterface> funcOps;
1433 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1434 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1435 funcOps.push_back(funcOp);
1436
1437 // Collect all GpuFuncOps in the module.
1438 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1439 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1440 funcOps.push_back(gpuFuncOp);
1441 }
1442 }
1443 // Print the analysis result for each function.
1444 for (FunctionOpInterface funcOp : funcOps)
1445 printFunctionResult(funcOp);
1446}
1447
1448namespace {
1449
1450//===----------------------------------------------------------------------===//
1451// ResolveLayoutConflicts
1452//===----------------------------------------------------------------------===//
1453
1454/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
1455/// function tries to find the defining CreateNdDescOp recursively accross
1456/// control-flow boundaries.
1457static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1458 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1459 auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
1460 if (definingOp)
1461 return definingOp;
1462 // If tdescValue is an argument, try to get the tied init value from the
1463 // parent loop-like op.
1464 if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1465 auto *parentOp = arg.getOwner()->getParentOp();
1466 if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1467 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1468 if (tiedInit)
1469 return getDefiningCreateNdDescOp(tiedInit->get());
1470 }
1471 }
1472 // If not found, return null.
1473 return nullptr;
1474}
1475
1476struct ResolveLayoutConflicts {
1477 ResolveLayoutConflicts(Operation *parentOp)
1478 : parentOp(parentOp), builder(parentOp->getContext()) {}
1479 LogicalResult run();
1480
1481private:
1482 Operation *parentOp;
1483 OpBuilder builder;
1484 LogicalResult resolveTensorDescConsumer(OpOperand &operand);
1485 LogicalResult resolveVectorConsumer(OpOperand &operand);
1486 LogicalResult assignResultLayout(OpResult &result);
1487};
1488
1489} // namespace
1490
1491LogicalResult ResolveLayoutConflicts::run() {
1492 // Scan all operations in the parent op and resolve layout conflicts at
1493 // tensor descriptor and vector use points.
1494 auto r = parentOp->walk([&](Operation *op) -> WalkResult {
1495 // if the operation inputs vector and output scalar, like multi-reduction we
1496 // need to check if the result has layout and add a convert_layout to serve
1497 // as anchor op for the reduction op's layout.
1498 if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
1499 for (OpResult result : op->getResults()) {
1500 if (result.getType().isIntOrFloat()) {
1501 auto res = assignResultLayout(result);
1502 if (failed(res)) {
1503 DBGS() << "Failed to resolve vector consumer for multi-reduction "
1504 << *op << "\n";
1505 return WalkResult::interrupt();
1506 }
1507 }
1508 }
1509 }
1510 for (OpOperand &operand : op->getOpOperands()) {
1511 // Handle conflicts in tensor descriptor operands.
1512 Type operandType = operand.get().getType();
1513 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1514 isa<xegpu::TensorDescType>(operandType)) {
1515 auto res = resolveTensorDescConsumer(operand);
1516 if (failed(res)) {
1517 DBGS() << "Failed to resolve tensor descriptor consumer: " << *op
1518 << "\n";
1519 return WalkResult::interrupt();
1520 }
1521 }
1522 // Handle conflicts in vector operands.
1523 if (isa<VectorType>(operandType)) {
1524 auto res = resolveVectorConsumer(operand);
1525 if (failed(res)) {
1526 DBGS() << "Failed to resolve vector consumer: " << *op << "\n";
1527 return WalkResult::interrupt();
1528 }
1529 }
1530 }
1531 return WalkResult::advance();
1532 });
1533
1534 LLVM_DEBUG({
1535 DBGS() << "IR after resolving layout conflicts:\n";
1536 parentOp->dump();
1537 });
1538
1539 return r.wasInterrupted() ? failure() : success();
1540}
1541
1542LogicalResult ResolveLayoutConflicts::assignResultLayout(OpResult &result) {
1543 Operation *producerOp = result.getDefiningOp();
1544 auto producerLayout = xegpu::getDistributeLayoutAttr(result);
1545 // Insert a convert_layout op to assign the layout.
1547 auto convertOp = xegpu::ConvertLayoutOp::create(
1548 builder, producerOp->getLoc(), result.getType(), result, producerLayout,
1549 producerLayout);
1550 result.replaceAllUsesExcept(convertOp.getResult(), convertOp);
1551 return success();
1552}
1553
1554LogicalResult
1555ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1556 Value vectorValue = operand.get();
1557 Operation *consumerOp = operand.getOwner();
1558 // Get the current layout of the vector value.
1559 auto producerLayout = xegpu::getDistributeLayoutAttr(vectorValue);
1560 if (!producerLayout) {
1561 if (auto vectorTy = dyn_cast<VectorType>(vectorValue.getType());
1562 vectorTy && vectorTy.getRank() > 1)
1563 consumerOp->emitWarning("Expected layout for non-1D vectors.");
1564 return success(); // uniform non-tensor-data vector does not require layout
1565 }
1566 // Region branch ops (e.g. scf.for) and their terminators (e.g. scf.yield)
1567 // forward their operands to successor region inputs / parent op results;
1568 // their consumer layout is resolved through that forwarding, not at this
1569 // use point.
1570 if (isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(
1571 consumerOp))
1572 return success();
1573
1574 auto consumerLayout = xegpu::getConsumerLayoutAt(operand);
1575 if (!consumerLayout)
1576 return consumerOp->emitError(
1577 "No consumer layout found for vector operand.");
1578
1579 // If layouts are same, no conflict exists, return success.
1580 if (consumerLayout.isEqualTo(producerLayout))
1581 return success();
1582
1583 // If the producer is trivially rematerializable (e.g. `vector.step`, splat
1584 // `arith.constant`), clone it and stamp the consumer's expected layout on
1585 // the clone instead of inserting a `xegpu.convert_layout`. The convert
1586 // would otherwise lower to a cross-subgroup data movement through SLM at
1587 // WG-to-SG distribution time, which is more expensive than
1588 // recomputing a pure value generator.
1589 if (auto *producerOp = vectorValue.getDefiningOp();
1590 producerOp && producerOp->getNumResults() == 1 &&
1591 isa<OpResult>(vectorValue) &&
1593 builder.setInsertionPointAfter(producerOp);
1594 Operation *clone = builder.clone(*producerOp);
1595 OpResult cloneResult = clone->getResult(0);
1596 // Drop the inherited producer layout so the new layout takes effect
1597 xegpu::removeLayoutAttr(cloneResult);
1598 xegpu::setDistributeLayoutAttr(cloneResult, consumerLayout);
1599 operand.set(cloneResult);
1600 return success();
1601 }
1602
1603 // Insert a convert_layout op to resolve the conflict.
1604 builder.setInsertionPointAfterValue(vectorValue);
1605 auto convertOp = xegpu::ConvertLayoutOp::create(
1606 builder, consumerOp->getLoc(), vectorValue.getType(), vectorValue,
1607 producerLayout, consumerLayout);
1608
1609 // Update the operand to use the converted value.
1610 operand.set(convertOp.getResult());
1611 return success();
1612}
1613
1614LogicalResult
1615ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1616 Operation *consumerOp = operand.getOwner();
1617 Value tdescValue = operand.get();
1618 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1619 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
1620 assert(anchorOp && currTDescType &&
1621 "Expected anchor layout op and tensor descriptor consumer.");
1622 Attribute currLayout = currTDescType.getLayout();
1623 Attribute expectedLayout = anchorOp.getAnchorLayout();
1624 // A conflict exists in tensor descriptor operand if tensor descriptor's
1625 // layout is different from the anchor layout expected by the consumer.
1626 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1627 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1628 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1629 if (!conflictingCreateNdOp) {
1630 DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
1631 << tdescValue << "\n";
1632 return failure();
1633 }
1634 // Duplicate the CreateNdDescOp with the expected layout.
1635 builder.setInsertionPointAfter(conflictingCreateNdOp);
1636 auto newTensorDescType = xegpu::TensorDescType::get(
1637 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1638 currTDescType.getElementType(), currTDescType.getEncoding(),
1639 expectedLayout);
1640 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1641 builder, consumerOp->getLoc(), newTensorDescType,
1642 conflictingCreateNdOp->getOperands(),
1643 conflictingCreateNdOp->getAttrs());
1644 // Replace the tensor descriptor operand in the consumer op with the new
1645 // tensor descriptor.
1646 consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
1647 }
1648 return success();
1649}
1650
1651using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1652/// Update an operation with the layout of its results. If the result type is
1653/// a vector type, a temporary layout attribute is added to the operation. If
1654/// the result type is a tensor descriptor type, the type is updated with the
1655/// layout attribute. The users of the result are also updated with the layout
1656/// attribute.
1657static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1658 GetLayoutFnTy getLayoutOfValue) {
1659 // Region ops (like scf.for) are already handled by the
1660 // updateControlFlowOps.
1661 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1662 return success();
1663
1664 // Iterate over all the results.
1665 for (OpResult result : op->getResults()) {
1666 Type resultType = result.getType();
1667 // Layouts are needed only for vector and tensor descriptor types.
1668 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1669 continue;
1670 // If the result has no layout but has users, emit a warning and continue.
1671 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1672 if (!layout && result.getNumUses() > 0) {
1673 op->emitWarning("op has users but no layout assigned for its result");
1674 continue;
1675 }
1676 // If the result is a tensor descriptor type, update the tensor desc type
1677 // with layout.
1678 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1679 auto typeWithLayout = xegpu::TensorDescType::get(
1680 tensorDescTy.getContext(), tensorDescTy.getShape(),
1681 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1682 result.setType(typeWithLayout);
1683 continue;
1684 }
1685 // If the result is a vector type, add a temporary layout attribute to the
1686 // op.
1688 }
1689 return success();
1690}
1691
1692/// Region ops like scf.for need special handling because they have blocks
1693/// inside. If the blocks have tensor descriptor type as block arguments,
1694/// thier types must be updated. Also region op can have results that may not
1695/// have any users (e.g. A and B tiles). They are not assigned a layout by
1696/// layout analysis because they have no users. However inside the region op
1697/// corresponding block arguments for these results do have layouts.
1698/// Therefore, in this case we still need to update the result types with the
1699/// layout attribute. This function function updates the internal block
1700/// arguments and the result types of the region op with the assigned layouts.
1701/// clang-format off
1702/// Example: scf.for ... iter_args(...) -> (out types) {
1703/// ^bb0(block types):
1704/// ...
1705/// scf.yield ... : (yield types)
1706/// }
1707/// clang-format on
1708/// In this example, at scf.yield, control-flow can transfer to two successor
1709/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1710/// itself (yield the results). So we update both the block arguments of the
1711/// successor region (i.e. block types) and the result types of the scf.for op
1712/// (i.e. out types). Note that yield types are updated by respective
1713/// producers inside bb0.
1714static LogicalResult
1716 mlir::RegionBranchTerminatorOpInterface terminator,
1717 GetLayoutFnTy getLayoutOfValue) {
1718 // Only process if the terminator is inside a region branch op.
1719 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1720 if (!branchOp)
1721 return success();
1722
1724 branchOp.getSuccessorOperandInputMapping(mapping,
1725 RegionBranchPoint(terminator));
1726 for (const auto &[successorOperand, successorInputs] : mapping) {
1727 for (Value successorInput : successorInputs) {
1728 Type inputType = successorInput.getType();
1729 // We only need to operate on tensor descriptor or vector types.
1730 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1731 continue;
1732 xegpu::DistributeLayoutAttr successorOperandLayout =
1733 getLayoutOfValue(successorOperand->get());
1734
1735 // If either of the layouts is not assigned, we cannot proceed.
1736 if (!successorOperandLayout) {
1737 LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1738 "branch terminator: "
1739 << successorOperand->get() << "\n");
1740 return failure();
1741 }
1742 // Get tensor descriptor type with the layout.
1743 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1744 auto newTdescTy = xegpu::TensorDescType::get(
1745 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1746 tdescTy.getEncoding(), successorOperandLayout);
1747 successorInput.setType(newTdescTy);
1748 continue;
1749 }
1750 // If the type is a vector type and this region argument is an OpResult,
1751 // set the layout attribute on the OpResult.
1752 if (auto result = dyn_cast<OpResult>(successorInput))
1753 xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1754 }
1755 }
1756 return success();
1757}
1758
1759/// Update the function arguments and results with the layouts.
1760static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1761 mlir::FunctionOpInterface funcOp,
1762 GetLayoutFnTy getLayoutOfValue) {
1763 // Only process functions whose type is a standard MLIR FunctionType.
1764 // Functions using a different type representation (e.g. llvm.func with
1765 // LLVMFunctionType) are not targets for XeGPU layout propagation, and
1766 // calling setType(FunctionType{}) on them would corrupt their type.
1767 if (!isa<FunctionType>(funcOp.getFunctionType()))
1768 return success();
1769 SmallVector<Type> newArgTypes;
1770 // Update the function arguments.
1771 for (BlockArgument arg : funcOp.getArguments()) {
1772 Type argType = arg.getType();
1773 newArgTypes.push_back(argType);
1774 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1775 continue;
1776 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1777 if (!layout) {
1778 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1779 << " but got none.\n");
1780 return failure();
1781 }
1782 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1783 auto newTdescTy = xegpu::TensorDescType::get(
1784 tensorDescTy.getContext(), tensorDescTy.getShape(),
1785 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1786 arg.setType(newTdescTy);
1787 newArgTypes.back() = newTdescTy;
1788 }
1789 }
1790 // Update the function type with the new argument types.
1791 // NOTE: We assume that function results are not expected to have layouts.
1792 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1793 funcOp.getResultTypes()));
1794 return success();
1795}
1796
1797namespace {
1798struct XeGPUPropagateLayoutPass final
1799 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1800 XeGPUPropagateLayoutPass() = default;
1801 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
1802 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1803 : XeGPUPropagateLayoutBase(std::move(options)) {}
1804 void runOnOperation() override;
1805};
1806
1807} // namespace
1808
1810 LayoutKind layoutKind,
1811 unsigned indexBitWidth, bool printOnly) {
1812 RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
1813 // Print the analysis result and exit. (for debugging purposes)
1814 if (printOnly) {
1815 auto &os = llvm::outs();
1816 analysis.printAnalysisResult(os);
1817 return success();
1818 }
1819 // Helper to convert LayoutInfo to xegpu::LayoutAttr.
1820 auto getLayoutFromPropagation =
1821 [&](Value val) -> xegpu::DistributeLayoutAttr {
1822 LayoutInfo layout = analysis.getLayoutInfo(val);
1823 if (auto opResult = dyn_cast<OpResult>(val)) {
1824 Operation *defOp = opResult.getDefiningOp();
1825 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1826 auto anchorLayout = anchorOp.getAnchorLayout();
1827 if (anchorLayout != nullptr)
1828 return anchorLayout;
1829 }
1830 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1831 xegpu::getTemporaryLayout(opResult);
1832 if (requiredResLayoutAttr != nullptr)
1833 return requiredResLayoutAttr;
1834 }
1835 if (!layout.isAssigned())
1836 return {};
1837 xegpu::DistributeLayoutAttr layoutAttr =
1838 cast<xegpu::DistributeLayoutAttr>(layout.get());
1839 if (layout.isSliceLayout())
1840 return cast<xegpu::SliceAttr>(layoutAttr);
1841
1842 return cast<xegpu::LayoutAttr>(layoutAttr);
1843 };
1844
1845 Operation *op = target;
1846 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1847 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1848 LogicalResult r = success();
1850 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1851 r = updateControlFlowOps(builder, branchTermOp,
1852 getLayoutFromPropagation);
1853 })
1854 .Case([&](mlir::RegionBranchOpInterface branchOp) {
1856 getLayoutFromPropagation);
1857 })
1858 .Case([&](mlir::FunctionOpInterface funcOp) {
1859 r = updateFunctionOpInterface(builder, funcOp,
1860 getLayoutFromPropagation);
1861 })
1862 .Default([&](Operation *op) {
1863 r = updateOp(builder, op, getLayoutFromPropagation);
1864 });
1865 if (failed(r)) {
1866 op.emitError("Failed to update operation with the layout.");
1867 return WalkResult::interrupt();
1868 }
1869 }
1870 return WalkResult::advance();
1871 });
1872 if (walkResult.wasInterrupted())
1873 return failure();
1874
1875 return success();
1876}
1877
1879 ResolveLayoutConflicts resolver(target);
1880 return resolver.run();
1881}
1882
1883void XeGPUPropagateLayoutPass::runOnOperation() {
1884
1885 xegpu::removeTemporaryLayoutAttrs(getOperation());
1886
1887 xegpu::LayoutKind layoutKind;
1888 if (this->layoutKind == "lane") {
1889 layoutKind = xegpu::LayoutKind::Lane;
1890 } else if (this->layoutKind == "inst") {
1891 layoutKind = xegpu::LayoutKind::InstData;
1892 } else if (this->layoutKind == "subgroup") {
1893 layoutKind = xegpu::LayoutKind::Subgroup;
1894 } else {
1895 getOperation()->emitError("Unsupported layout kind option: " +
1896 this->layoutKind);
1897 signalPassFailure();
1898 return;
1899 }
1900 OpBuilder builder(&getContext());
1901 if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
1902 this->indexBitWidth, this->printOnly))) {
1903 signalPassFailure();
1904 return;
1905 }
1906 // Resolve layout conflicts if any.
1907 if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
1908 signalPassFailure();
1909 return;
1910 }
1911}
return success()
#define DBGS()
Definition Hoisting.cpp:32
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
lhs
b getContext())
auto load
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
The general data-flow analysis solver.
LogicalResult initializeAndRun(Operation *top, llvm::function_ref< bool(DataFlowAnalysis &)> analysisFilter=nullptr)
Initialize analyses starting from the provided top-level operation and run the analysis until fixpoin...
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:408
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr setupInterleaveResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an interleave operation to ensure the source layout can be safely deriv...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr inferInterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for an interleave operation given the result layout attribute.
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, VectorType aScaleTy, VectorType bScaleTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and B_scale).
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
DistributeLayoutAttr inferDeinterleaveSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a deinterleave operation given the result layout attribute.
DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
bool isTriviallyRematerializable(Operation *op)
Returns true if op is safe and cheap to clone: it has no side effects, no regions,...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
LogicalResult propagateRegionArgsToInits(RegionBranchOpInterface regionOp, GetLayoutFnTy getLayoutOfValue)
Propagate layouts from a region branch op's region entry block arguments back to its init operands.
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:168