MLIR 23.0.0git
XeGPUPropagateLayout.cpp
Go to the documentation of this file.
1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/Visitors.h"
31#include "mlir/Support/LLVM.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51
52using namespace mlir;
53using namespace mlir::dataflow;
54
55namespace {
56
57//===----------------------------------------------------------------------===//
58// LayoutInfo
59//===----------------------------------------------------------------------===//
60
61/// Helper class for tracking the analysis state of an mlir value. For layout
62/// propagation, the analysis state is simply the distribution layout of
63/// each value. The distribution layout information is encapsulated using
64/// xegpu::DistributeLayoutAttr class which can hold information about any type
65/// of distribution layout that XeGPU dialect supports. Purpose of this analysis
66/// to propagate some unique distribution layout for each value in the program
67/// starting from a set of anchor operations (like DPAS, StoreNd, etc.). Note
68/// that the analysis reaches a fixed point once every value has been assigned
69/// some layout, and the analysis does not modify any already-assigned layouts.
70///
71/// Given this, LayoutInfo satisfies the following properties:
72/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
73/// assigned`.
74/// 2) Two LayoutInfo values are equal if they are both assigned or
75/// both not assigned. The concrete value of assigned state does not matter.
76/// 3) The meet operator works as follows:
77/// - If current state is assigned, return the current state. (already
78/// a unique layout is assigned. don't change it)
79/// - Otherwise, return the other state.
80
81struct LayoutInfo {
82private:
83 xegpu::DistributeLayoutAttr storage = nullptr;
84
85public:
86 LayoutInfo() = default;
87 LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
88
89 // Two lattice values are equal if they have `some` layout. The actual
90 // content of the layout does not matter.
91 bool operator==(const LayoutInfo &other) const {
92 return this->isAssigned() == other.isAssigned();
93 }
94
95 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
96
97 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
98
99 void print(raw_ostream &os) const;
100
101 bool isAssigned() const { return storage != nullptr; }
102
103 LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
104
105 SmallVector<int> getLaneLayout() const;
106
107 SmallVector<int> getLaneData() const;
108
109 SmallVector<int> getInstData() const;
110
111 SmallVector<int> getSgLayout() const;
112
113 SmallVector<int> getSgData() const;
114
115 SmallVector<int> getOrder() const;
116
117 bool isSliceLayout() const {
118 if (!isAssigned())
119 return false;
120 return isa<xegpu::SliceAttr>(storage);
121 }
122
123 int64_t getRank() const {
124 if (!isAssigned())
125 return -1;
126 return storage.getRank();
127 }
128
129 Attribute get() { return storage; }
130 void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
131};
132
133SmallVector<int> LayoutInfo::getLaneLayout() const {
134 if (!isAssigned())
135 return {};
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](int64_t val) { return static_cast<int>(val); });
138}
139
140SmallVector<int> LayoutInfo::getLaneData() const {
141 if (!isAssigned())
142 return {};
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](int64_t val) { return static_cast<int>(val); });
145}
146
147SmallVector<int> LayoutInfo::getInstData() const {
148 if (!isAssigned())
149 return {};
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](int64_t val) { return static_cast<int>(val); });
152}
153
154SmallVector<int> LayoutInfo::getSgLayout() const {
155 if (!isAssigned())
156 return {};
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](int64_t val) { return static_cast<int>(val); });
159}
160
161SmallVector<int> LayoutInfo::getSgData() const {
162 if (!isAssigned())
163 return {};
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](int64_t val) { return static_cast<int>(val); });
166}
167
168SmallVector<int> LayoutInfo::getOrder() const {
169 if (!isAssigned() || !storage.getOrder())
170 return {};
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](int64_t val) { return static_cast<int>(val); });
173}
174
175void LayoutInfo::print(raw_ostream &os) const {
176 if (isAssigned()) {
177 os << storage;
178 } else {
179 os << "Not assigned.";
180 }
181}
182
183LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
184 if (!lhs.isAssigned())
185 return rhs;
186 return lhs;
187}
188
189/// Since this is a backward analysis, join method is not used.
190LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
191 llvm_unreachable("Join should not be triggered by layout propagation.");
192}
193
194/// Construct a new layout with the transposed inst_data or lane_layout,
195/// lane_data.
196LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
197 if (!isAssigned())
198 return {};
199 // Check if the permutation is valid.
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
204 });
205
206 if (!withinRange || hasDuplicates) {
207 assert(false && "Invalid permutation for transpose.");
208 return {};
209 }
210
211 SmallVector<int32_t> laneLayout;
212 SmallVector<int32_t> laneData;
213 SmallVector<int32_t> instData;
214 SmallVector<int32_t> sgLayout;
217
218 for (int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
221 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
222 }
223 if (getInstData().size())
224 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
227 sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
228 }
229 if (getOrder().size()) {
230 order.push_back(static_cast<int32_t>(getOrder()[idx]));
231 }
232 }
233 auto orderAttr = order.size()
234 ? DenseI32ArrayAttr::get(storage.getContext(), order)
235 : nullptr;
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
238 layoutAttr =
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
245 DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
246 DenseI32ArrayAttr::get(storage.getContext(), sgData),
247 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
248 /*lane_data =*/nullptr, orderAttr);
249 return LayoutInfo(layoutAttr);
250}
251
252//===----------------------------------------------------------------------===//
253// LayoutInfoLattice
254//===----------------------------------------------------------------------===//
255
256/// Lattice holding the LayoutInfo for each value.
257struct LayoutInfoLattice : public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
260};
261
262/// Helper Functions to get default layouts. A `default layout` is a layout that
263/// is assigned to a value when the layout is not fixed by some anchor operation
264/// (like DPAS).
265
266/// Helper Function to get the default layout for uniform values like constants.
267/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
268/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
269static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
270 unsigned rank,
271 const xegpu::uArch::uArch *uArch) {
272 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
273 if (rank == 1) {
274 return LayoutInfo(
275 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
276 }
277 return LayoutInfo(
278 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
279}
280
281/// Helper to get the default layout for 2D block operations.
282template <typename Ty>
283static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
285 unsigned packingSize) {
286 // Expecting a 1D or 2D vector.
287 assert((ty.getRank() == 1 || ty.getRank() == 2) &&
288 "Expected 1D or 2D vector.");
289 // Expecting int or float element type.
290 assert(ty.getElementType().isIntOrFloat() &&
291 "Expected int or float element type.");
292 // If the rank is 1, then return default layout for 1D vector.
293 if (ty.getRank() == 1)
294 return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
295 // Packing factor is determined by the element type bitwidth.
296 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
297 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
298 return LayoutInfo(xegpu::LayoutAttr::get(
299 ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
300}
301
302//===----------------------------------------------------------------------===//
303// LayoutInfoPropagation
304//===----------------------------------------------------------------------===//
305
306/// Backward data flow analysis to propagate the lane_layout and lane_data of
307/// each value in the program. Currently, the layouts for operands DPAS,
308/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
309/// this analysis is to propagate those known layouts to all their producers and
310/// (other) consumers.
311class LayoutInfoPropagation
312 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
313private:
314 xegpu::LayoutKind layoutKind;
315 unsigned indexBitWidth;
316 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
318
319 void visitStoreNdOp(xegpu::StoreNdOp store,
322
323 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
326
327 void visitLoadNdOp(xegpu::LoadNdOp load,
330
331 void visitLoadGatherOp(xegpu::LoadGatherOp load,
334
335 void visitTransposeOp(vector::TransposeOp transpose,
338
339 void visitVectorBitcastOp(vector::BitCastOp bitcast,
342
343 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
346
347 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
350
351 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
354
355 void visitVectorReductionOp(vector::ReductionOp reduction,
358
359 void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
362 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
365 void
366 visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
369
370 void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
373
374 void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
377
378 void visitLoadGatherOp(xegpu::LoadMatrixOp load,
381
382 void visitStoreScatterOp(xegpu::StoreMatrixOp store,
385
386 void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
389
390 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
391
392public:
393 LayoutInfoPropagation(DataFlowSolver &solver,
394 SymbolTableCollection &symbolTable,
395 xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
396 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
397 layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
399
400 LogicalResult
401 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
402 ArrayRef<const LayoutInfoLattice *> results) override;
403
404 void visitBranchOperand(OpOperand &operand) override {};
405
406 void visitCallOperand(OpOperand &operand) override {};
407
408 void
409 visitNonControlFlowArguments(RegionSuccessor &successor,
410 ArrayRef<BlockArgument> arguments) override {};
411
412 void visitExternalCall(CallOpInterface call,
414 ArrayRef<const LayoutInfoLattice *> results) override {
415 };
416
417 void setToExitState(LayoutInfoLattice *lattice) override {
418 (void)lattice->meet(LayoutInfo());
419 }
420};
421} // namespace
422
423LogicalResult LayoutInfoPropagation::visitOperation(
424 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
425 ArrayRef<const LayoutInfoLattice *> results) {
427 .Case(
428 [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
429 .Case([&](xegpu::StoreNdOp storeNdOp) {
430 visitStoreNdOp(storeNdOp, operands, results);
431 })
432 .Case([&](xegpu::StoreScatterOp storeScatterOp) {
433 visitStoreScatterOp(storeScatterOp, operands, results);
434 })
435 .Case([&](xegpu::LoadNdOp loadNdOp) {
436 visitLoadNdOp(loadNdOp, operands, results);
437 })
438 .Case([&](xegpu::LoadGatherOp loadGatherOp) {
439 visitLoadGatherOp(loadGatherOp, operands, results);
440 })
441 .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
442 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
443 })
444 .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
445 visitPrefetchNdOp(prefetchNdOp, operands, results);
446 })
447 .Case([&](vector::TransposeOp transposeOp) {
448 visitTransposeOp(transposeOp, operands, results);
449 })
450 .Case([&](vector::BitCastOp bitcastOp) {
451 visitVectorBitcastOp(bitcastOp, operands, results);
452 })
453 .Case([&](vector::MultiDimReductionOp reductionOp) {
454 visitVectorMultiReductionOp(reductionOp, operands, results);
455 })
456 .Case([&](vector::ReductionOp reductionOp) {
457 visitVectorReductionOp(reductionOp, operands, results);
458 })
459 .Case([&](vector::BroadcastOp broadcastOp) {
460 visitVectorBroadCastOp(broadcastOp, operands, results);
461 })
462 .Case([&](vector::ShapeCastOp shapeCastOp) {
463 visitShapeCastOp(shapeCastOp, operands, results);
464 })
465 .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
466 visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
467 })
468 .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
469 visitLoadMatrixOp(loadMatrixOp, operands, results);
470 })
471 .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
472 visitStoreMatrixOp(storeMatrixOp, operands, results);
473 })
474 .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
475 visitConvertLayoutOp(convertLayoutOp, operands, results);
476 })
477 // All other ops.
478 .Default([&](Operation *op) {
479 for (const LayoutInfoLattice *resultInfo : results) {
480 if (!resultInfo->getValue().isAssigned())
481 continue;
482 for (auto [operandInfo, operand] :
483 llvm::zip(operands, op->getOpOperands())) {
484 // If the operand type is not a vector or tensor descriptor, skip
485 // it.
486 if (!isa<xegpu::TensorDescType, VectorType>(
487 operand.get().getType()))
488 continue;
489 // Propagate the result layout to the operand.
490 meet(operandInfo, *resultInfo);
491 }
492 }
493 });
494
495 return success();
496}
497
498bool LayoutInfoPropagation::hasParamsOfLayoutKind(
499 xegpu::DistributeLayoutAttr anchorLayout) {
500 if (anchorLayout == nullptr) {
501 return false;
502 }
503 if (layoutKind == xegpu::LayoutKind::InstData) {
504 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
505 }
506 if (layoutKind == xegpu::LayoutKind::Lane) {
507 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
508 anchorLayout.getEffectiveLaneDataAsInt().empty());
509 }
510 if (layoutKind == xegpu::LayoutKind::Subgroup) {
511 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
512 anchorLayout.getEffectiveSgDataAsInt().empty());
513 }
514 return false;
515}
516
517// This function returns all layouts for the given sgCount, whose sgData:
518// 1. Evenly divides the wgShape.
519// 2. Is a multiple of instData.
520// Example:
521// wgShape = [128, 64], instData = [8, 16], sgCount = 32
522// Returns layouts:
523// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
525 ArrayRef<int> instData,
526 int64_t sgCount) {
528 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
529 if (sgCount % sgLayout0)
530 continue;
531 int sgLayout1 = sgCount / sgLayout0;
532 int sgData0 = wgShape[0] / sgLayout0;
533 int sgData1 = wgShape[1] / sgLayout1;
534 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
535 (sgData0 % instData[0] || sgData1 % instData[1]))
536 continue;
537 candidates.emplace_back(sgLayout0, sgLayout1);
538 }
539 // Sort primarily by how balanced they are
540 // (i.e., minimize the absolute difference between the two dimensions), and
541 // secondarily by the first dimension in ascending order.
542 llvm::sort(candidates, [](const std::pair<int, int> &lhs,
543 const std::pair<int, int> &rhs) {
544 int diffLhs = std::abs(lhs.first - lhs.second);
545 int diffRhs = std::abs(rhs.first - rhs.second);
546 if (diffLhs != diffRhs)
547 return diffLhs < diffRhs;
548 return lhs.first < rhs.first;
549 });
550 return candidates;
551}
552
553FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
554 // Oblivious to workitem layout, the total count matters.
555 auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
556 if (!gpuFunc)
557 return failure();
558 auto knownBlockSize = gpuFunc.getKnownBlockSize();
559 if (!knownBlockSize.has_value())
560 return failure();
561 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
562 return flatBlockSize / sgSize;
563}
564
565void LayoutInfoPropagation::visitPrefetchNdOp(
566 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
567 ArrayRef<const LayoutInfoLattice *> results) {
568
569 LayoutInfo prefetchLayout;
570 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
571 if (hasParamsOfLayoutKind(anchorLayout)) {
572 prefetchLayout = LayoutInfo(anchorLayout);
573 } else {
574 // Here we assign the default layout to the tensor descriptor operand of
575 // prefetch.
576 auto tdescTy = prefetch.getTensorDescType();
577
578 const uArch *uArch = getUArch(getChipStr(prefetch).value_or(""));
579 if (!uArch)
580 return;
581 const auto *uArchInstruction =
582 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
583 uArch->getInstruction(
584 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
585
586 auto blockWHC =
587 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
588 if (!blockWHC)
589 prefetch.emitWarning("No known block params found for the element type.");
590 auto [bWidth, bHeight, bCount] = blockWHC.value();
591 SmallVector<int> instData;
592 int instWidth = xegpu::getLargestDivisor(
593 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
594 if (instWidth == -1)
595 prefetch.emitWarning(
596 "No suitable instruction multiple found for the given shape.");
597 if (tdescTy.getRank() == 1)
598 instData = {instWidth};
599 else {
600 int instHeight = xegpu::getLargestDivisor(
601 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
602 if (instHeight == -1)
603 prefetch.emitWarning(
604 "No suitable instruction multiple found for the given shape.");
605 instData = {instHeight, instWidth};
606 }
607
608 if (layoutKind == xegpu::LayoutKind::InstData)
609 prefetchLayout =
610 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
611 else
612 prefetchLayout = getSIMTLayoutInfoBlockIO(
613 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
614
615 prefetch.setLayoutAttr(
616 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
617 }
618 // Propagate the layout to the source tensor descriptor.
619 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
620}
621
622void LayoutInfoPropagation::visitVectorMultiReductionOp(
623 vector::MultiDimReductionOp reduction,
624 ArrayRef<LayoutInfoLattice *> operands,
625 ArrayRef<const LayoutInfoLattice *> results) {
626 Type resultTy = reduction.getDestType();
627 // The layout of the result must be present.
628 LayoutInfo resLayoutInfo = results[0]->getValue();
629
630 xegpu::DistributeLayoutAttr consumerLayoutAttr;
631 if (!resultTy.isIntOrFloat()) {
632 if (!resLayoutInfo.isAssigned())
633 return;
634 consumerLayoutAttr =
635 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
636 }
637
638 VectorType sourceTy = reduction.getSourceVectorType();
639 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
640
641 const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
642 if (!uArch)
643 return;
644 int numSg = 0;
645 if (layoutKind == xegpu::LayoutKind::Subgroup) {
646 auto numSgOrErr = getNumSg(reduction, uArch->getSubgroupSize());
647 if (succeeded(numSgOrErr))
648 numSg = numSgOrErr.value();
649 }
650
651 // The result layout represents the layout requirements of the operation.
652 // it is recorded to anchor layout or temporary layout.
653 // it must be honored for current op and may conflict with the layout
654 // propagated from consumer op, the conflict is resolved in later phase by
655 // converting the required result layout to the consumer layout
656 auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
657 layoutKind, sourceTy, consumerLayoutAttr, reductionDims, numSg, uArch);
658
659 xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
660
661 // derive the source layout from the dominant layout and reduction dims
662 auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
663 requiredResLayoutAttr, reductionDims);
664
665 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
666 // Accumulator should have the same layout as the result.
667 propagateIfChanged(operands[1],
668 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
669}
670
671void LayoutInfoPropagation::visitVectorReductionOp(
672 vector::ReductionOp reduction, ArrayRef<LayoutInfoLattice *> operands,
673 ArrayRef<const LayoutInfoLattice *> results) {
674
675 VectorType sourceTy = reduction.getSourceVectorType();
676 const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
677 if (!uArch)
678 return;
679
680 auto requiredResLayoutAttr =
681 xegpu::setupReductionResultLayout(layoutKind, sourceTy, uArch);
682 xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
683
684 auto srcLayoutAttr = xegpu::inferReductionSourceLayout(requiredResLayoutAttr);
685 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
686 if (reduction.getAcc())
687 propagateIfChanged(operands[1],
688 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
689}
690
691void LayoutInfoPropagation::visitVectorBroadCastOp(
692 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
693 ArrayRef<const LayoutInfoLattice *> results) {
694 // The layout of the result must be present.
695 LayoutInfo resLayoutInfo = results[0]->getValue();
696 if (!resLayoutInfo.isAssigned())
697 return;
698
699 // Only consider vector to vector broadcasts for now.
700 VectorType resultTy = broadcast.getResultVectorType();
701 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
702 // skip layout propagation for non-vector source operand.
703 if (!sourceTy)
704 return;
705
706 auto srcShape = sourceTy.getShape();
707 auto resShape = resultTy.getShape();
708
709 auto resultLayoutAttr =
710 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
711
712 xegpu::DistributeLayoutAttr srcLayoutAttr =
713 xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
714
715 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
716}
717
718void LayoutInfoPropagation::visitShapeCastOp(
719 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
720 ArrayRef<const LayoutInfoLattice *> results) {
721 // The layout of the result must be present.
722 LayoutInfo resLayoutInfo = results[0]->getValue();
723 if (!resLayoutInfo.isAssigned())
724 return;
725 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
726 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
727 auto resultLayoutAttr =
728 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
729
730 xegpu::DistributeLayoutAttr srcLayoutAttr =
731 xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
732
733 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
734}
735
736/// Propagate the layout of the result tensor to the source tensor descriptor
737/// in UpdateNdOffsetOp.
738void LayoutInfoPropagation::visitUpdateNdOffsetOp(
739 xegpu::UpdateNdOffsetOp updateNdOffset,
740 ArrayRef<LayoutInfoLattice *> operands,
741 ArrayRef<const LayoutInfoLattice *> results) {
742 // The layout of the result must be present.
743 LayoutInfo resultLayout = results[0]->getValue();
744 if (!resultLayout.isAssigned())
745 return;
746 // Propagate the layout to the source operand.
747 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
748}
749
750/// Set the layouts for DPAS A, B, and C operands.
751void LayoutInfoPropagation::visitDpasOp(
752 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
753 ArrayRef<const LayoutInfoLattice *> results) {
754 LayoutInfo dpasALayout;
755 LayoutInfo dpasBLayout;
756 LayoutInfo dpasCDLayout;
757
758 xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
759 if (hasParamsOfLayoutKind(anchorLayoutCD)) {
760 xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
761 xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
762 assert(hasParamsOfLayoutKind(anchorLayoutA) &&
763 "Expected anchor layout for DPAS A operand.");
764 assert(hasParamsOfLayoutKind(anchorLayoutB) &&
765 "Expected anchor layout for DPAS B operand.");
766 dpasALayout = LayoutInfo(anchorLayoutA);
767 dpasBLayout = LayoutInfo(anchorLayoutB);
768 dpasCDLayout = LayoutInfo(anchorLayoutCD);
769 } else {
770 const uArch *uArch = getUArch(getChipStr(dpas).value_or(""));
771 if (!uArch)
772 return;
773 VectorType aTy = dpas.getLhsType();
774 VectorType bTy = dpas.getRhsType();
775 VectorType cdTy = dpas.getResultType();
776
777 xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
778 xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
779 requiredBLayout;
780
781 int numSg = 0;
782 if (layoutKind == xegpu::LayoutKind::Subgroup) {
783 LayoutInfo consumerLayout = results[0]->getValue();
784 if (!consumerLayout.isAssigned())
785 return;
786 consumerLayoutAttr =
787 dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
788 auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
789 if (failed(numSgOrErr)) {
790 dpas.emitWarning(
791 "Unable to determine the number of subgroups for the operation.");
792 return;
793 }
794 numSg = numSgOrErr.value();
795 }
796 auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
797 consumerLayoutAttr, numSg, uArch);
798 if (!layouts.has_value()) {
799 dpas.emitWarning(
800 "Failed to determine required layouts for DPAS operands.");
801 return;
802 }
803
804 std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
805
806 dpas.setLayoutAAttr(requiredALayout);
807 dpas.setLayoutBAttr(requiredBLayout);
808 dpas.setLayoutCdAttr(requiredCDLayoutAttr);
809 dpasALayout = LayoutInfo(requiredALayout);
810 dpasBLayout = LayoutInfo(requiredBLayout);
811 dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
812 }
813 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
814 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
815 if (operands.size() > 2)
816 propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
817}
818
819/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
820void LayoutInfoPropagation::visitStoreNdOp(
821 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
822 ArrayRef<const LayoutInfoLattice *> results) {
823 LayoutInfo storeLayout;
824 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
825 if (hasParamsOfLayoutKind(anchorLayout)) {
826 storeLayout = LayoutInfo(anchorLayout);
827 } else {
828 const uArch *uArch = getUArch(getChipStr(store).value_or(""));
829 if (!uArch)
830 return;
831 const auto *uArchInstruction =
832 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
833 uArch->getInstruction(
834 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
835 VectorType dataTy = store.getValueType();
836 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
837 store.getValueType().getElementType());
838 if (!blockWHC)
839 store.emitWarning("No known block params found for the element type.");
840 auto [bWidth, bHeight, bCount] = blockWHC.value();
841 SmallVector<int> instData;
842 int instWidth = xegpu::getLargestDivisor(
843 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
844 if (instWidth == -1)
845 store.emitWarning(
846 "No suitable instruction multiple found for the given shape.");
847 if (dataTy.getRank() == 1)
848 instData = {instWidth};
849 else {
850 int instHeight = xegpu::getLargestDivisor(
851 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
852 if (instHeight == -1)
853 store.emitWarning(
854 "No suitable instruction multiple found for the given shape.");
855 instData = {instHeight, instWidth};
856 }
857
858 if (layoutKind == xegpu::LayoutKind::InstData)
859 storeLayout =
860 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
861 else if (layoutKind == xegpu::LayoutKind::Lane)
862 storeLayout =
863 getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
864 uArchInstruction->getPackedFormatBitSize());
865 else { // xegpu::LayoutKind::Subgroup
866 auto sgSize = uArch->getSubgroupSize();
867 auto numSgOrErr = getNumSg(store, sgSize);
868 if (failed(numSgOrErr)) {
869 store.emitWarning(
870 "Unable to determine the number of subgroups for the operation.");
871 return;
872 }
873 auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
874 instData, numSgOrErr.value());
875 if (sgLayouts.empty()) {
876 store.emitWarning(
877 "Unable to determine suitable subgroup layout for store value.");
878 return;
879 }
880 SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
881 SmallVector<int> sgData = {
882 static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
883 static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
884 storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
885 dataTy.getContext(),
886 DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
887 DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
888 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
889 /*lane_data =*/nullptr, /*order =*/nullptr));
890 }
891 store.setLayoutAttr(
892 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
893 }
894 // Propagate the layout to the value operand.
895 // Both operands should have the same layout
896 for (LayoutInfoLattice *operand : operands)
897 propagateIfChanged(operand, operand->meet(storeLayout));
898}
899
900/// Propagate the layout of the value to the tensor descriptor operand in
901/// LoadNdOp.
902void LayoutInfoPropagation::visitLoadNdOp(
903 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
904 ArrayRef<const LayoutInfoLattice *> results) {
905 LayoutInfo loadLayout;
906 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
907 if (hasParamsOfLayoutKind(anchorLayout)) {
908 loadLayout = LayoutInfo(anchorLayout);
909 } else {
910
911 LayoutInfo valueLayout = results[0]->getValue();
912 // Need the layout of the value to propagate to the tensor descriptor.
913 if (!valueLayout.isAssigned())
914 return;
915 loadLayout = valueLayout;
916 // LoadNdOp has the transpose effect. However, at the stage of this analysis
917 // this effect is not expected and should be abstracted away. Emit a
918 // warning.
919 if (auto transpose = load.getTranspose()) {
920 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
921 "LayoutInfoPropagation stage.");
922 loadLayout = valueLayout.transpose(transpose.value());
923 }
924 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
925 }
926 // Propagate the new layout to the tensor descriptor operand.
927 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
928}
929
930/// Propagate the layout of the value to the tensor descriptor operand in
931/// ConvertLayoutOp.
932void LayoutInfoPropagation::visitConvertLayoutOp(
933 xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
934 ArrayRef<const LayoutInfoLattice *> results) {
935 xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
936 LayoutInfo convertLayout(anchorLayout);
937 // Propagate the new layout to the tensor descriptor operand.
938 propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
939}
940
941/// For vector::TransposeOp, the layout of the result is transposed and
942/// propagated to the operand.
943void LayoutInfoPropagation::visitTransposeOp(
944 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
945 ArrayRef<const LayoutInfoLattice *> results) {
946 // Need the layout of transpose result to propagate to the operands.
947 LayoutInfo resultLayout = results[0]->getValue();
948 if (!resultLayout.isAssigned())
949 return;
950 auto consumerLayoutAttr =
951 dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
952 auto srcLayoutAttr = xegpu::inferTransposeSourceLayout(
953 consumerLayoutAttr, transpose.getPermutation());
954 // Propagate the new layout to the vector operand.
955 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
956}
957
958/// For vector::BitCastOp, the lane_data of the source layout is changed based
959/// on the bit width of the source and result types.
960void LayoutInfoPropagation::visitVectorBitcastOp(
961 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
962 ArrayRef<const LayoutInfoLattice *> results) {
963 // Need the layout of bitcast result to propagate to the operands.
964 LayoutInfo resLayoutInfo = results[0]->getValue();
965 if (!resLayoutInfo.isAssigned())
966 return;
967
968 auto srcVecType = bitcast.getSourceVectorType();
969 auto resVecType = bitcast.getResultVectorType();
970
971 auto consumerLayoutAttr =
972 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
973 const uArch *uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
974 if (!uArch)
975 return;
976 auto requiredResLayoutAttr = setupBitCastResultLayout(
977 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
978
979 xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
980
981 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
982 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
983
984 // derive the source layout from the dominant layout and reduction dims
985 auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
986 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
987
988 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
989}
990
991void LayoutInfoPropagation::visitInsertStridedSliceOp(
992 vector::InsertStridedSliceOp insertStridedSlice,
993 ArrayRef<LayoutInfoLattice *> operands,
994 ArrayRef<const LayoutInfoLattice *> results) {
995 // The layout of the result must be present.
996 LayoutInfo resLayoutInfo = results[0]->getValue();
997 if (!resLayoutInfo.isAssigned())
998 return;
999
1000 auto srcVecType = insertStridedSlice.getSourceVectorType();
1001 auto resVecType = insertStridedSlice.getDestVectorType();
1002
1003 auto consumerLayoutAttr =
1004 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1005 const uArch *uArch =
1006 getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
1007 if (!uArch)
1008 return;
1009
1010 auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
1011 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
1012 xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
1013 requiredResLayoutAttr);
1014
1015 auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
1016 requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
1017 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
1018 propagateIfChanged(operands[1],
1019 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
1020}
1021
1022/// Propagate the layout of the result to the tensor descriptor, mask and offset
1023/// operands in LoadGatherOp.
1024void LayoutInfoPropagation::visitLoadGatherOp(
1025 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
1026 ArrayRef<const LayoutInfoLattice *> results) {
1027 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1028 xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
1029 const uArch *uArch = getUArch(getChipStr(load).value_or(""));
1030 if (!uArch)
1031 return;
1032 VectorType resVecTy = load.getValueType();
1033 int chunkSize = load.getChunkSize().value_or(1);
1034
1035 LayoutInfo resLayoutInfo = results[0]->getValue();
1036 if (!resLayoutInfo.isAssigned())
1037 return;
1038 auto consumerLayoutAttr =
1039 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1040
1041 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1042 requiredAnchorLayoutAttr = anchorLayoutAttr;
1043 } else {
1044 if (!resVecTy) {
1045 load.emitWarning("Not propagating, non-vector payload supplied.");
1046 return;
1047 }
1048 requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
1049 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1050 load.setLayoutAttr(requiredAnchorLayoutAttr);
1051 }
1052
1053 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1054 auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
1055 requiredAnchorLayoutAttr, chunkSize);
1056 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1057 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1058
1059 // Propagate the new layout to the tensor descriptor operand.
1060 if (isa<xegpu::TensorDescType>(load.getSourceType()))
1061 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1062 // Propagate the new layout to the offset and mask operands.
1063 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1064 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1065}
1066
1067/// Set the layout for the value, tensor descriptor, offset and mask operands in
1068/// the StoreScatterOp.
1069void LayoutInfoPropagation::visitStoreScatterOp(
1070 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1071 ArrayRef<const LayoutInfoLattice *> results) {
1072
1073 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1074 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1075 const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
1076 if (!uArch)
1077 return;
1078 VectorType srcVecTy = storeScatter.getValueType();
1079 int chunkSize = storeScatter.getChunkSize().value_or(1);
1080
1081 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1082 requiredAnchorLayoutAttr = anchorLayoutAttr;
1083 } else {
1084 if (!srcVecTy) {
1085 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
1086 return;
1087 }
1088 requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
1089 layoutKind, srcVecTy, chunkSize, uArch);
1090 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1091 }
1092
1093 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1094 assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
1095 auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
1096 requiredAnchorLayoutAttr, chunkSize);
1097 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1098
1099 // Propagate the payload operand layout
1100 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1101 // Propagate the destination (if tdesc) operand layout
1102 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1103 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1104 // Propagate the new layout to the offset and mask operands.
1105 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1106 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
1107}
1108
1109void LayoutInfoPropagation::visitLoadMatrixOp(
1110 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1111 ArrayRef<const LayoutInfoLattice *> results) {
1112
1113 LayoutInfo resLayoutInfo = results[0]->getValue();
1114 if (!resLayoutInfo.isAssigned())
1115 return;
1116
1117 auto consumerLayoutAttr =
1118 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1119
1120 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1121
1122 // only need to set anchor layout, no need to porpagate to memdesc and
1123 // offset
1124 if (!hasParamsOfLayoutKind(anchorLayout)) {
1125 VectorType resVecTy =
1126 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1127 const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
1128 if (!uArch)
1129 return;
1130 auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
1131 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1132 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
1133 }
1134}
1135
1136// Store matrix is a flavor of scattered store for 2D shapes.
1137void LayoutInfoPropagation::visitStoreMatrixOp(
1138 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1139 ArrayRef<const LayoutInfoLattice *> results) {
1140 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1141 LayoutInfo layout;
1142 if (hasParamsOfLayoutKind(anchorLayout)) {
1143 layout = LayoutInfo(anchorLayout);
1144 } else {
1145 VectorType srcVecTy =
1146 llvm::cast<VectorType>(storeMatrix.getData().getType());
1147 const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
1148 if (!uArch)
1149 return;
1150 auto requiredAnchorLayoutAttr =
1151 xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
1152 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1153 layout = LayoutInfo(requiredAnchorLayoutAttr);
1154 }
1155
1156 propagateIfChanged(operands[0], operands[0]->meet(layout));
1157}
1158
1159namespace {
1160//===----------------------------------------------------------------------===//
1161// RunLayoutInfoPropagation
1162//===----------------------------------------------------------------------===//
1163
/// Driver class for running the LayoutInfoPropagation analysis.
/// Construction eagerly runs the dataflow solver over `op`; results are then
/// queried through getLayoutInfo().
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  /// Loads the baseline analyses plus LayoutInfoPropagation into the solver
  /// and runs them over `op` with the given layout kind and index bit width.
  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
                           unsigned indexBitWidth)
      : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
    // Result deliberately discarded; presumably a failed run simply leaves
    // lattice states unassigned -- TODO confirm.
    (void)solver.initializeAndRun(op);
  }

  /// Returns the layout computed for `val`, or an unassigned LayoutInfo if
  /// the solver holds no state for it.
  LayoutInfo getLayoutInfo(Value val);

  /// Dumps per-function layout results to `os` (debugging aid).
  void printAnalysisResult(llvm::raw_ostream &os);

private:
  // Solver owning the analysis states.
  DataFlowSolver solver;
  // Root operation the analysis was run on.
  const Operation *target;
};
1186} // namespace
1187
1188LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1189 auto *state = solver.lookupState<LayoutInfoLattice>(val);
1190 if (!state)
1191 return {};
1192 return state->getValue();
1193}
1194
1195// Print the analysis result for debugging purposes.
1196void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1197 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1198 os << "function: " << funcOp.getName() << ":\n";
1199 // Function arguments
1200 for (BlockArgument arg : funcOp.getArguments()) {
1201 LayoutInfo layout = getLayoutInfo(arg);
1202 os << "argument: " << arg << "\n";
1203 os << "layout : ";
1204 layout.print(os);
1205 os << "\n";
1206 }
1207 // Function ops
1208 funcOp.walk([&](Operation *op) {
1209 // Skip ops that do not have results
1210 if (op->getResults().empty())
1211 return;
1212 os << "op : ";
1213 // For control-flow ops, print the op name only.
1214 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1215 os << op->getName();
1216 else
1217 op->print(os);
1218 os << "\n";
1219 // Print the layout for each result.
1220 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1221 LayoutInfo layout = getLayoutInfo(r);
1222 os << "layout for result #" << i << ": ";
1223 layout.print(os);
1224 os << "\n";
1225 }
1226 });
1227 };
1228
1229 SmallVector<FunctionOpInterface> funcOps;
1230 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1231 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1232 funcOps.push_back(funcOp);
1233
1234 // Collect all GpuFuncOps in the module.
1235 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1236 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1237 funcOps.push_back(gpuFuncOp);
1238 }
1239 }
1240 // Print the analysis result for each function.
1241 for (FunctionOpInterface funcOp : funcOps)
1242 printFunctionResult(funcOp);
1243}
1244
1245namespace {
1246
1247//===----------------------------------------------------------------------===//
1248// ResolveLayoutConflicts
1249//===----------------------------------------------------------------------===//
1250
1251/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
1252/// function tries to find the defining CreateNdDescOp recursively accross
1253/// control-flow boundaries.
1254static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1255 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1256 auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
1257 if (definingOp)
1258 return definingOp;
1259 // If tdescValue is an argument, try to get the tied init value from the
1260 // parent loop-like op.
1261 if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1262 auto *parentOp = arg.getOwner()->getParentOp();
1263 if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1264 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1265 if (tiedInit)
1266 return getDefiningCreateNdDescOp(tiedInit->get());
1267 }
1268 }
1269 // If not found, return null.
1270 return nullptr;
1271}
1272
/// Post-propagation fixup: walks the IR under `parentOp` and repairs places
/// where a value's assigned layout differs from the layout its consumer
/// expects (by inserting convert_layout ops or duplicating descriptors).
struct ResolveLayoutConflicts {
  ResolveLayoutConflicts(Operation *parentOp)
      : parentOp(parentOp), builder(parentOp->getContext()) {}
  /// Scans all nested ops; fails if a conflict cannot be resolved.
  LogicalResult run();

private:
  // Root op whose nested operations are scanned.
  Operation *parentOp;
  // Used to create convert_layout / duplicated descriptor ops.
  OpBuilder builder;
  /// Fix a layout mismatch on a tensor descriptor operand of an anchor op.
  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
  /// Fix a layout mismatch on a vector operand.
  LogicalResult resolveVectorConsumer(OpOperand &operand);
  /// Anchor a reduction result's layout via a convert_layout op.
  LogicalResult assignResultLayout(OpResult &result);
};
1285
1286} // namespace
1287
1288LogicalResult ResolveLayoutConflicts::run() {
1289 // Scan all operations in the parent op and resolve layout conflicts at
1290 // tensor descriptor and vector use points.
1291 auto r = parentOp->walk([&](Operation *op) -> WalkResult {
1292 // if the operation inputs vector and output scalar, like multi-reduction we
1293 // need to check if the result has layout and add a convert_layout to serve
1294 // as anchor op for the reduction op's layout.
1295 if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
1296 for (OpResult result : op->getResults()) {
1297 if (result.getType().isIntOrFloat()) {
1298 auto res = assignResultLayout(result);
1299 if (failed(res)) {
1300 DBGS() << "Failed to resolve vector consumer for multi-reduction "
1301 << *op << "\n";
1302 return WalkResult::interrupt();
1303 }
1304 }
1305 }
1306 }
1307 for (OpOperand &operand : op->getOpOperands()) {
1308 // Handle conflicts in tensor descriptor operands.
1309 Type operandType = operand.get().getType();
1310 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1311 isa<xegpu::TensorDescType>(operandType)) {
1312 auto res = resolveTensorDescConsumer(operand);
1313 if (failed(res)) {
1314 DBGS() << "Failed to resolve tensor descriptor consumer: " << *op
1315 << "\n";
1316 return WalkResult::interrupt();
1317 }
1318 }
1319 // Handle conflicts in vector operands.
1320 if (isa<VectorType>(operandType)) {
1321 auto res = resolveVectorConsumer(operand);
1322 if (failed(res)) {
1323 DBGS() << "Failed to resolve vector consumer: " << *op << "\n";
1324 return WalkResult::interrupt();
1325 }
1326 }
1327 }
1328 return WalkResult::advance();
1329 });
1330
1331 return r.wasInterrupted() ? failure() : success();
1332}
1333
/// Anchor the layout of `result`: inserts a convert_layout op with identical
/// input and output layouts (so it acts purely as a layout anchor) and
/// redirects all other uses of `result` to it.
LogicalResult ResolveLayoutConflicts::assignResultLayout(OpResult &result) {
  Operation *producerOp = result.getDefiningOp();
  auto producerLayout = xegpu::getDistributeLayoutAttr(result);
  // Insert a convert_layout op to assign the layout.
  // NOTE(review): no insertion point is set on `builder` here (original
  // source line 1338 is absent from this listing) -- confirm where the
  // convert op is actually created against the upstream source.
  auto convertOp = xegpu::ConvertLayoutOp::create(
      builder, producerOp->getLoc(), result.getType(), result, producerLayout,
      producerLayout);
  result.replaceAllUsesExcept(convertOp.getResult(), convertOp);
  return success();
}
1345
1346LogicalResult
1347ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1348 Value vectorValue = operand.get();
1349 Operation *consumerOp = operand.getOwner();
1350 // Get the current layout of the vector value.
1351 auto producerLayout = xegpu::getDistributeLayoutAttr(vectorValue);
1352 if (!producerLayout) {
1353 if (auto vectorTy = dyn_cast<VectorType>(vectorValue.getType());
1354 vectorTy && vectorTy.getRank() > 1)
1355 consumerOp->emitWarning("Expected layout for non-1D vectors.");
1356 return success(); // uniform non-tensor-data vector does not require layout
1357 }
1358 // Get the consumer expected layout at this operand.
1359 auto consumerLayout = xegpu::getConsumerLayoutAt(operand);
1360 if (!consumerLayout)
1361 return consumerOp->emitError(
1362 "No consumer layout found for vector operand.");
1363
1364 // If layouts are same, no conflict exists, return success.
1365 if (consumerLayout.isEqualTo(producerLayout))
1366 return success();
1367
1368 // Insert a convert_layout op to resolve the conflict.
1369 builder.setInsertionPointAfterValue(vectorValue);
1370 auto convertOp = xegpu::ConvertLayoutOp::create(
1371 builder, consumerOp->getLoc(), vectorValue.getType(), vectorValue,
1372 producerLayout, consumerLayout);
1373
1374 // Update the operand to use the converted value.
1375 operand.set(convertOp.getResult());
1376 return success();
1377}
1378
1379LogicalResult
1380ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1381 Operation *consumerOp = operand.getOwner();
1382 Value tdescValue = operand.get();
1383 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1384 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
1385 assert(anchorOp && currTDescType &&
1386 "Expected anchor layout op and tensor descriptor consumer.");
1387 Attribute currLayout = currTDescType.getLayout();
1388 Attribute expectedLayout = anchorOp.getAnchorLayout();
1389 // A conflict exists in tensor descriptor operand if tensor descriptor's
1390 // layout is different from the anchor layout expected by the consumer.
1391 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1392 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1393 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1394 if (!conflictingCreateNdOp) {
1395 DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
1396 << tdescValue << "\n";
1397 return failure();
1398 }
1399 // Duplicate the CreateNdDescOp with the expected layout.
1400 builder.setInsertionPointAfter(conflictingCreateNdOp);
1401 auto newTensorDescType = xegpu::TensorDescType::get(
1402 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1403 currTDescType.getElementType(), currTDescType.getEncoding(),
1404 expectedLayout);
1405 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1406 builder, consumerOp->getLoc(), newTensorDescType,
1407 conflictingCreateNdOp->getOperands(),
1408 conflictingCreateNdOp->getAttrs());
1409 // Replace the tensor descriptor operand in the consumer op with the new
1410 // tensor descriptor.
1411 consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
1412 }
1413 return success();
1414}
1415
/// Callback used to look up the layout assigned to a value.
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
/// the result type is a tensor descriptor type, the type is updated with the
/// layout attribute. The users of the result are also updated with the layout
/// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are already handled by the
  // updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor types.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has no layout but has users, emit a warning and continue.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // If the result is a tensor descriptor type, update the tensor desc type
    // with layout.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // If the result is a vector type, add a temporary layout attribute to the
    // op.
    // NOTE(review): the statement implementing this (original source line
    // 1452) is not visible in this listing -- confirm against the upstream
    // source.
  }
  return success();
}
1456
1457/// Region ops like scf.for need special handling because they have blocks
1458/// inside. If the blocks have tensor descriptor type as block arguments,
1459/// thier types must be updated. Also region op can have results that may not
1460/// have any users (e.g. A and B tiles). They are not assigned a layout by
1461/// layout analysis because they have no users. However inside the region op
1462/// corresponding block arguments for these results do have layouts.
1463/// Therefore, in this case we still need to update the result types with the
1464/// layout attribute. This function function updates the internal block
1465/// arguments and the result types of the region op with the assigned layouts.
1466/// clang-format off
1467/// Example: scf.for ... iter_args(...) -> (out types) {
1468/// ^bb0(block types):
1469/// ...
1470/// scf.yield ... : (yield types)
1471/// }
1472/// clang-format on
1473/// In this example, at scf.yield, control-flow can transfer to two successor
1474/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1475/// itself (yield the results). So we update both the block arguments of the
1476/// successor region (i.e. block types) and the result types of the scf.for op
1477/// (i.e. out types). Note that yield types are updated by respective
1478/// producers inside bb0.
1479static LogicalResult
1481 mlir::RegionBranchTerminatorOpInterface terminator,
1482 GetLayoutFnTy getLayoutOfValue) {
1483 // Only process if the terminator is inside a region branch op.
1484 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1485 if (!branchOp)
1486 return success();
1487
1489 branchOp.getSuccessorOperandInputMapping(mapping,
1490 RegionBranchPoint(terminator));
1491 for (const auto &[successorOperand, successorInputs] : mapping) {
1492 for (Value successorInput : successorInputs) {
1493 Type inputType = successorInput.getType();
1494 // We only need to operate on tensor descriptor or vector types.
1495 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1496 continue;
1497 xegpu::DistributeLayoutAttr successorInputLayout =
1498 getLayoutOfValue(successorInput);
1499 xegpu::DistributeLayoutAttr successorOperandLayout =
1500 getLayoutOfValue(successorOperand->get());
1501
1502 // If either of the layouts is not assigned, we cannot proceed.
1503 if (!successorOperandLayout) {
1504 LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1505 "branch terminator: "
1506 << successorOperand->get() << "\n");
1507 return failure();
1508 }
1509 // We expect the layouts to match.
1510 if (successorInputLayout &&
1511 successorInputLayout != successorOperandLayout) {
1512 LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1513 "operand forwarded as the argument: "
1514 << successorInputLayout << " vs "
1515 << successorOperandLayout << "\n");
1516 return failure();
1517 }
1518 // Get tensor descriptor type with the layout.
1519 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1520 auto newTdescTy = xegpu::TensorDescType::get(
1521 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1522 tdescTy.getEncoding(), successorOperandLayout);
1523 successorInput.setType(newTdescTy);
1524 continue;
1525 }
1526 // If the type is a vector type and this region argument is an OpResult,
1527 // set the layout attribute on the OpResult.
1528 if (auto result = dyn_cast<OpResult>(successorInput))
1529 xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1530 }
1531 }
1532 return success();
1533}
1534
1535/// Update the function arguments and results with the layouts.
1536static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1537 mlir::FunctionOpInterface funcOp,
1538 GetLayoutFnTy getLayoutOfValue) {
1539 // Only process functions whose type is a standard MLIR FunctionType.
1540 // Functions using a different type representation (e.g. llvm.func with
1541 // LLVMFunctionType) are not targets for XeGPU layout propagation, and
1542 // calling setType(FunctionType{}) on them would corrupt their type.
1543 if (!isa<FunctionType>(funcOp.getFunctionType()))
1544 return success();
1545 SmallVector<Type> newArgTypes;
1546 // Update the function arguments.
1547 for (BlockArgument arg : funcOp.getArguments()) {
1548 Type argType = arg.getType();
1549 newArgTypes.push_back(argType);
1550 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1551 continue;
1552 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1553 if (!layout) {
1554 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1555 << " but got none.\n");
1556 return failure();
1557 }
1558 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1559 auto newTdescTy = xegpu::TensorDescType::get(
1560 tensorDescTy.getContext(), tensorDescTy.getShape(),
1561 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1562 arg.setType(newTdescTy);
1563 newArgTypes.back() = newTdescTy;
1564 }
1565 }
1566 // Update the function type with the new argument types.
1567 // NOTE: We assume that function results are not expected to have layouts.
1568 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1569 funcOp.getResultTypes()));
1570 return success();
1571}
1572
1573namespace {
/// Pass wrapper driving layout propagation followed by conflict resolution
/// (see runOnOperation).
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  /// Constructs the pass from tablegen-generated options (layout kind, index
  /// bit width, print-only mode).
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(std::move(options)) {}
  void runOnOperation() override;
};
1582
1583} // namespace
1584
1586 LayoutKind layoutKind,
1587 unsigned indexBitWidth, bool printOnly) {
1588 RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
1589 // Print the analysis result and exit. (for debugging purposes)
1590 if (printOnly) {
1591 auto &os = llvm::outs();
1592 analysis.printAnalysisResult(os);
1593 return success();
1594 }
1595 // Helper to convert LayoutInfo to xegpu::LayoutAttr.
1596 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1597 LayoutInfo layout = analysis.getLayoutInfo(val);
1598 if (auto opResult = dyn_cast<OpResult>(val)) {
1599 Operation *defOp = opResult.getDefiningOp();
1600 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1601 auto anchorLayout = anchorOp.getAnchorLayout();
1602 if (anchorLayout != nullptr)
1603 return anchorLayout;
1604 }
1605 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1606 xegpu::getTemporaryLayout(opResult);
1607 if (requiredResLayoutAttr != nullptr)
1608 return requiredResLayoutAttr;
1609 }
1610 if (!layout.isAssigned())
1611 return {};
1612 xegpu::DistributeLayoutAttr layoutAttr =
1613 cast<xegpu::DistributeLayoutAttr>(layout.get());
1614 if (layout.isSliceLayout())
1615 return cast<xegpu::SliceAttr>(layoutAttr);
1616
1617 return cast<xegpu::LayoutAttr>(layoutAttr);
1618 };
1619
1620 Operation *op = target;
1621 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1622 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1623 LogicalResult r = success();
1625 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1626 r = updateControlFlowOps(builder, branchTermOp,
1627 getXeGPULayoutForValue);
1628 })
1629 .Case([&](mlir::FunctionOpInterface funcOp) {
1630 r = updateFunctionOpInterface(builder, funcOp,
1631 getXeGPULayoutForValue);
1632 })
1633 .Default([&](Operation *op) {
1634 r = updateOp(builder, op, getXeGPULayoutForValue);
1635 });
1636 if (failed(r)) {
1637 op.emitError("Failed to update operation with the layout.");
1638 return WalkResult::interrupt();
1639 }
1640 }
1641 return WalkResult::advance();
1642 });
1643 if (walkResult.wasInterrupted())
1644 return failure();
1645
1646 return success();
1647}
1648
1650 ResolveLayoutConflicts resolver(target);
1651 return resolver.run();
1652}
1653
1654void XeGPUPropagateLayoutPass::runOnOperation() {
1655 // Clean up temporary layout attributes
1656 getOperation()->walk([](Operation *op) {
1657 SmallVector<StringAttr> attrsToRemove;
1658 for (auto namedAttr : op->getDiscardableAttrs()) {
1659 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
1660 attrsToRemove.push_back(namedAttr.getName());
1661 }
1662 for (auto attrName : attrsToRemove)
1663 op->removeDiscardableAttr(attrName);
1664 });
1665 xegpu::LayoutKind layoutKind;
1666 if (this->layoutKind == "lane") {
1667 layoutKind = xegpu::LayoutKind::Lane;
1668 } else if (this->layoutKind == "inst") {
1669 layoutKind = xegpu::LayoutKind::InstData;
1670 } else if (this->layoutKind == "subgroup") {
1671 layoutKind = xegpu::LayoutKind::Subgroup;
1672 } else {
1673 getOperation()->emitError("Unsupported layout kind option: " +
1674 this->layoutKind);
1675 signalPassFailure();
1676 return;
1677 }
1678 OpBuilder builder(&getContext());
1679 if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
1680 this->indexBitWidth, this->printOnly))) {
1681 signalPassFailure();
1682 return;
1683 }
1684 // Resolve layout conflicts if any.
1685 if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
1686 signalPassFailure();
1687 return;
1688 }
1689}
return success()
#define DBGS()
Definition Hoisting.cpp:32
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
lhs
b getContext())
auto load
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
The general data-flow analysis solver.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Definition Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
void print(raw_ostream &os, const OpPrintingFlags &flags={})
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
Definition Operation.h:498
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getResults()
Definition Operation.h:441
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, int numSg, const uArch::uArch *uArch)
Sets up layout for Multi-Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout, int chunkSize)
Infers the layout attribute for mask and offset operand for Chunked load and store,...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr inferReductionSourceLayout(DistributeLayoutAttr resLayout)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg, const uArch::uArch *uArch)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
SliceAttr setupReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, const uArch::uArch *uArch)
Sets up layout for Reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:163