MLIR 23.0.0git
XeGPUPropagateLayout.cpp
Go to the documentation of this file.
1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/Visitors.h"
31#include "mlir/Support/LLVM.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51
52using namespace mlir;
53using namespace mlir::dataflow;
54
55namespace {
56
57//===----------------------------------------------------------------------===//
58// LayoutInfo
59//===----------------------------------------------------------------------===//
60
/// Helper class for tracking the analysis state of an mlir value. For layout
/// propagation, the analysis state is simply the distribution layout of
/// each value. The distribution layout information is encapsulated using the
/// xegpu::DistributeLayoutAttr class, which can hold information about any type
/// of distribution layout that the XeGPU dialect supports. The purpose of this
/// analysis is to propagate some unique distribution layout for each value in
/// the program, starting from a set of anchor operations (like DPAS, StoreNd,
/// etc.). Note that the analysis reaches a fixed point once every value has
/// been assigned some layout, and the analysis does not try to modify any
/// already-assigned layouts.
///
/// Given this, LayoutInfo satisfies the following properties:
/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
/// assigned`.
/// 2) Two LayoutInfo values are equal if they are both assigned or
/// both not assigned. The concrete value of the assigned state does not matter.
/// 3) The meet operator works as follows:
///    - If the current state is assigned, return the current state (a unique
///      layout is already assigned; don't change it).
///    - Otherwise, return the other state.
struct LayoutInfo {
private:
  // The wrapped layout attribute; nullptr means "not assigned".
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  /// Lattice meet: keeps the current layout if one is already assigned,
  /// otherwise adopts the other one.
  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  /// Lattice join; unused by this backward analysis (asserts if called).
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  /// Prints the layout attribute, or "Not assigned." when unassigned.
  void print(raw_ostream &os) const;

  /// Returns true if a concrete layout has been assigned.
  bool isAssigned() const { return storage != nullptr; }

  /// Returns a new LayoutInfo with all present layout parameters permuted.
  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  // Accessors for the effective layout parameters. Each returns an empty
  // vector when no layout is assigned (or the parameter is absent).
  SmallVector<int> getLaneLayout() const;

  SmallVector<int> getLaneData() const;

  SmallVector<int> getInstData() const;

  SmallVector<int> getSgLayout() const;

  SmallVector<int> getSgData() const;

  SmallVector<int> getOrder() const;

  /// Returns true if the assigned layout is a SliceAttr.
  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  /// Rank of the assigned layout, or -1 when not assigned.
  int64_t getRank() const {
    if (!isAssigned())
      return -1;
    return storage.getRank();
  }

  /// Raw access to the underlying attribute (may be null).
  Attribute get() { return storage; }
  void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
};
132
133SmallVector<int> LayoutInfo::getLaneLayout() const {
134 if (!isAssigned())
135 return {};
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](int64_t val) { return static_cast<int>(val); });
138}
139
140SmallVector<int> LayoutInfo::getLaneData() const {
141 if (!isAssigned())
142 return {};
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](int64_t val) { return static_cast<int>(val); });
145}
146
147SmallVector<int> LayoutInfo::getInstData() const {
148 if (!isAssigned())
149 return {};
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](int64_t val) { return static_cast<int>(val); });
152}
153
154SmallVector<int> LayoutInfo::getSgLayout() const {
155 if (!isAssigned())
156 return {};
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](int64_t val) { return static_cast<int>(val); });
159}
160
161SmallVector<int> LayoutInfo::getSgData() const {
162 if (!isAssigned())
163 return {};
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](int64_t val) { return static_cast<int>(val); });
166}
167
168SmallVector<int> LayoutInfo::getOrder() const {
169 if (!isAssigned() || !storage.getOrder())
170 return {};
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](int64_t val) { return static_cast<int>(val); });
173}
174
175void LayoutInfo::print(raw_ostream &os) const {
176 if (isAssigned()) {
177 os << storage;
178 } else {
179 os << "Not assigned.";
180 }
181}
182
183LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
184 if (!lhs.isAssigned())
185 return rhs;
186 return lhs;
187}
188
/// Since this is a backward analysis, join method is not used.
/// Calling it is a programming error; it aborts via llvm_unreachable.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
193
194/// Construct a new layout with the transposed inst_data or lane_layout,
195/// lane_data.
196LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
197 if (!isAssigned())
198 return {};
199 // Check if the permutation is valid.
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
204 });
206 if (!withinRange || hasDuplicates) {
207 assert(false && "Invalid permutation for transpose.");
208 return {};
209 }
210
212 SmallVector<int32_t> laneData;
214 SmallVector<int32_t> sgLayout;
217
218 for (int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
221 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
223 if (getInstData().size())
224 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
227 sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
228 }
229 if (getOrder().size()) {
230 order.push_back(static_cast<int32_t>(getOrder()[idx]));
231 }
233 auto orderAttr = order.size()
234 ? DenseI32ArrayAttr::get(storage.getContext(), order)
235 : nullptr;
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
238 layoutAttr =
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
245 DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
246 DenseI32ArrayAttr::get(storage.getContext(), sgData),
247 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
248 /*lane_data =*/nullptr, orderAttr);
249 return LayoutInfo(layoutAttr);
252//===----------------------------------------------------------------------===//
253// LayoutInfoLattice
254//===----------------------------------------------------------------------===//
255
/// Lattice holding the LayoutInfo for each value. Thin wrapper that reuses
/// the generic dataflow Lattice constructors.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
261
262/// Helper Functions to get default layouts. A `default layout` is a layout that
263/// is assigned to a value when the layout is not fixed by some anchor operation
264/// (like DPAS).
265
266/// Helper Function to get the default layout for uniform values like constants.
267/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
268/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
269static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
270 unsigned rank,
271 const xegpu::uArch::uArch *uArch) {
272 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
273 if (rank == 1) {
274 return LayoutInfo(
275 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
276 }
277 return LayoutInfo(
278 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
279}
280
281static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
282 unsigned rank, int subgroupSize) {
283 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
284 if (rank == 1) {
285 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
286 }
287 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
288}
289
290/// Helper to get the default layout for 2D block operations.
291template <typename Ty>
292static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
293 const xegpu::uArch::uArch *uArch,
294 unsigned packingSize) {
295 // Expecting a 1D or 2D vector.
296 assert((ty.getRank() == 1 || ty.getRank() == 2) &&
297 "Expected 1D or 2D vector.");
298 // Expecting int or float element type.
299 assert(ty.getElementType().isIntOrFloat() &&
300 "Expected int or float element type.");
301 // If the rank is 1, then return default layout for 1D vector.
302 if (ty.getRank() == 1)
303 return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
304 // Packing factor is determined by the element type bitwidth.
305 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
306 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307 return LayoutInfo(xegpu::LayoutAttr::get(
308 ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
309}
310
311//===----------------------------------------------------------------------===//
312// LayoutInfoPropagation
313//===----------------------------------------------------------------------===//
314
/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for operands DPAS,
/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
/// this analysis is to propagate those known layouts to all their producers and
/// (other) consumers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  // Which family of layout parameters (lane / inst / subgroup) is being
  // propagated in this run.
  xegpu::LayoutKind layoutKind;
  // Bit width used for index computations.
  unsigned indexBitWidth;
  // Per-op visitors. Each reads the result lattices and propagates the
  // required layouts backwards into the operand lattices.
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
  void
  visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
                          ArrayRef<LayoutInfoLattice *> operands,
                          ArrayRef<const LayoutInfoLattice *> results);

  // NOTE(review): the two overloads below take LoadMatrixOp/StoreMatrixOp
  // despite their gather/scatter names and are not dispatched from
  // visitOperation — possibly leftovers; confirm they are still needed.
  void visitLoadGatherOp(xegpu::LoadMatrixOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreMatrixOp store,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  // Returns true if `anchorLayout` carries parameters of the kind this
  // analysis is propagating (layoutKind).
  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  // Control-flow and call operands are intentionally ignored: layouts are
  // only propagated through the op-specific visitors above.
  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void
  visitNonControlFlowArguments(RegionSuccessor &successor,
                               ArrayRef<BlockArgument> arguments) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  // Exit state is "not assigned" — meet with a default LayoutInfo.
  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
426} // namespace
427
428LogicalResult LayoutInfoPropagation::visitOperation(
429 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
430 ArrayRef<const LayoutInfoLattice *> results) {
432 .Case(
433 [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
434 .Case([&](xegpu::StoreNdOp storeNdOp) {
435 visitStoreNdOp(storeNdOp, operands, results);
436 })
437 .Case([&](xegpu::StoreScatterOp storeScatterOp) {
438 visitStoreScatterOp(storeScatterOp, operands, results);
439 })
440 .Case([&](xegpu::LoadNdOp loadNdOp) {
441 visitLoadNdOp(loadNdOp, operands, results);
442 })
443 .Case([&](xegpu::LoadGatherOp loadGatherOp) {
444 visitLoadGatherOp(loadGatherOp, operands, results);
445 })
446 .Case([&](xegpu::CreateDescOp createDescOp) {
447 visitCreateDescOp(createDescOp, operands, results);
448 })
449 .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
450 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
451 })
452 .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
453 visitPrefetchNdOp(prefetchNdOp, operands, results);
454 })
455 .Case([&](vector::TransposeOp transposeOp) {
456 visitTransposeOp(transposeOp, operands, results);
457 })
458 .Case([&](vector::BitCastOp bitcastOp) {
459 visitVectorBitcastOp(bitcastOp, operands, results);
460 })
461 .Case([&](vector::MultiDimReductionOp reductionOp) {
462 visitVectorMultiReductionOp(reductionOp, operands, results);
463 })
464 .Case([&](vector::BroadcastOp broadcastOp) {
465 visitVectorBroadCastOp(broadcastOp, operands, results);
466 })
467 .Case([&](vector::ShapeCastOp shapeCastOp) {
468 visitShapeCastOp(shapeCastOp, operands, results);
469 })
470 .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
471 visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
472 })
473 .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
474 visitLoadMatrixOp(loadMatrixOp, operands, results);
475 })
476 .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
477 visitStoreMatrixOp(storeMatrixOp, operands, results);
478 })
479 // All other ops.
480 .Default([&](Operation *op) {
481 for (const LayoutInfoLattice *resultInfo : results) {
482 if (!resultInfo->getValue().isAssigned())
483 continue;
484 for (auto [operandInfo, operand] :
485 llvm::zip(operands, op->getOpOperands())) {
486 // If the operand type is not a vector or tensor descriptor, skip
487 // it.
488 if (!isa<xegpu::TensorDescType, VectorType>(
489 operand.get().getType()))
490 continue;
491 // Propagate the result layout to the operand.
492 meet(operandInfo, *resultInfo);
493 }
494 }
495 });
496
497 return success();
498}
499
500bool LayoutInfoPropagation::hasParamsOfLayoutKind(
501 xegpu::DistributeLayoutAttr anchorLayout) {
502 if (anchorLayout == nullptr) {
503 return false;
504 }
505 if (layoutKind == xegpu::LayoutKind::InstData) {
506 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
507 }
508 if (layoutKind == xegpu::LayoutKind::Lane) {
509 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
510 anchorLayout.getEffectiveLaneDataAsInt().empty());
511 }
512 if (layoutKind == xegpu::LayoutKind::Subgroup) {
513 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
514 anchorLayout.getEffectiveSgDataAsInt().empty());
515 }
516 return false;
517}
518
519// This function returns all layouts for the given sgCount, whose sgData:
520// 1. Evenly divides the wgShape.
521// 2. Is a multiple of instData.
522// Example:
523// wgShape = [128, 64], instData = [8, 16], sgCount = 32
524// Returns layouts:
525// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
527 ArrayRef<int> instData,
528 int64_t sgCount) {
530 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
531 if (sgCount % sgLayout0)
532 continue;
533 int sgLayout1 = sgCount / sgLayout0;
534 int sgData0 = wgShape[0] / sgLayout0;
535 int sgData1 = wgShape[1] / sgLayout1;
536 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
537 (sgData0 % instData[0] || sgData1 % instData[1]))
538 continue;
539 candidates.emplace_back(sgLayout0, sgLayout1);
540 }
541 // Sort primarily by how balanced they are
542 // (i.e., minimize the absolute difference between the two dimensions), and
543 // secondarily by the first dimension in ascending order.
544 llvm::sort(candidates, [](const std::pair<int, int> &lhs,
545 const std::pair<int, int> &rhs) {
546 int diffLhs = std::abs(lhs.first - lhs.second);
547 int diffRhs = std::abs(rhs.first - rhs.second);
548 if (diffLhs != diffRhs)
549 return diffLhs < diffRhs;
550 return lhs.first < rhs.first;
551 });
552 return candidates;
553}
554
555FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
556 // Oblivious to workitem layout, the total count matters.
557 auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
558 if (!gpuFunc)
559 return failure();
560 auto knownBlockSize = gpuFunc.getKnownBlockSize();
561 if (!knownBlockSize.has_value())
562 return failure();
563 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
564 return flatBlockSize / sgSize;
565}
566
567void LayoutInfoPropagation::visitPrefetchNdOp(
568 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
569 ArrayRef<const LayoutInfoLattice *> results) {
570
571 LayoutInfo prefetchLayout;
572 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
573 if (hasParamsOfLayoutKind(anchorLayout)) {
574 prefetchLayout = LayoutInfo(anchorLayout);
575 } else {
576 // Here we assign the default layout to the tensor descriptor operand of
577 // prefetch.
578 auto tdescTy = prefetch.getTensorDescType();
579
580 const uArch *uArch = getUArch(getChipStr(prefetch).value_or(""));
581 if (!uArch)
582 return;
583 const auto *uArchInstruction =
584 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
585 uArch->getInstruction(
586 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
587
588 auto blockWHC =
589 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
590 if (!blockWHC)
591 prefetch.emitWarning("No known block params found for the element type.");
592 auto [bWidth, bHeight, bCount] = blockWHC.value();
593 SmallVector<int> instData;
594 int instWidth = xegpu::getLargestDivisor(
595 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
596 if (instWidth == -1)
597 prefetch.emitWarning(
598 "No suitable instruction multiple found for the given shape.");
599 if (tdescTy.getRank() == 1)
600 instData = {instWidth};
601 else {
602 int instHeight = xegpu::getLargestDivisor(
603 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
604 if (instHeight == -1)
605 prefetch.emitWarning(
606 "No suitable instruction multiple found for the given shape.");
607 instData = {instHeight, instWidth};
608 }
609
610 if (layoutKind == xegpu::LayoutKind::InstData)
611 prefetchLayout =
612 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
613 else
614 prefetchLayout = getSIMTLayoutInfoBlockIO(
615 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
616
617 prefetch.setLayoutAttr(
618 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
619 }
620 // Propagate the layout to the source tensor descriptor.
621 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
622}
623
/// Propagate layouts through vector.multi_reduction: the required result
/// layout is derived from the consumer layout and recorded as a temporary
/// layout; the source layout is then inferred from it and the reduction dims.
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;

  VectorType sourceTy = reduction.getSourceVectorType();
  SmallVector<int64_t> reductionDims(reduction.getReductionDims());

  // No uArch for the target chip means we cannot compute layouts; give up.
  const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
  if (!uArch)
    return;
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

  // The result layout represents the layout requirements of the operation.
  // it is recorded to anchor layout or temporary layout.
  // it must be honored for current op and may conflict with the layout
  // propagated from consumer op, the conflict is resolved in later phase by
  // converting the required result layout to the consumer layout
  auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
      layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);

  xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);

  // derive the source layout from the dominant layout and reduction dims
  auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
      requiredResLayoutAttr, reductionDims);

  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
  // Accumulator should have the same layout as the result.
  propagateIfChanged(operands[1],
                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
}
661
/// Propagate the result layout of a vector.broadcast back to its (vector)
/// source by inferring a compatible source layout.
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;

  // Only consider vector to vector broadcasts for now.
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  // skip layout propagation for non-vector source operand.
  if (!sourceTy)
    return;

  auto srcShape = sourceTy.getShape();
  auto resShape = resultTy.getShape();

  // Equal ranks means broadcasting along unit dims. In that case the source
  // is expected to come from a shape_cast (asserted below).
  // NOTE(review): dimDiff is size_t; this assumes result rank >= source rank
  // (guaranteed by vector.broadcast semantics) — confirm.
  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
  if (dimDiff == 0) {
    Operation *srcOp = broadcast.getSource().getDefiningOp();
    if (!srcOp)
      return;
    [[maybe_unused]] bool hasUnitDim =
        llvm::any_of(srcShape, [](int64_t dim) { return dim == 1; });
    assert(
        hasUnitDim && isa<vector::ShapeCastOp>(srcOp) &&
        "When broadcasting from unit-dim, the producer op must be shape_cast!");
  }

  auto resultLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

  xegpu::DistributeLayoutAttr srcLayoutAttr =
      xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);

  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}
700
701void LayoutInfoPropagation::visitShapeCastOp(
702 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
703 ArrayRef<const LayoutInfoLattice *> results) {
704 // The layout of the result must be present.
705 LayoutInfo resLayoutInfo = results[0]->getValue();
706 if (!resLayoutInfo.isAssigned())
707 return;
708 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
709 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
710 auto resultLayoutAttr =
711 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
712
713 xegpu::DistributeLayoutAttr srcLayoutAttr =
714 xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
715
716 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
717}
718
719/// Propagate the layout of the result tensor to the source tensor descriptor
720/// in UpdateNdOffsetOp.
721void LayoutInfoPropagation::visitUpdateNdOffsetOp(
722 xegpu::UpdateNdOffsetOp updateNdOffset,
723 ArrayRef<LayoutInfoLattice *> operands,
724 ArrayRef<const LayoutInfoLattice *> results) {
725 // The layout of the result must be present.
726 LayoutInfo resultLayout = results[0]->getValue();
727 if (!resultLayout.isAssigned())
728 return;
729 // Propagate the layout to the source operand.
730 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
731}
732
/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  // If the op already carries anchor layouts of the kind being propagated,
  // reuse them for all three operands.
  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    // Otherwise derive layouts from the target uArch.
    const uArch *uArch = getUArch(getChipStr(dpas).value_or(""));
    if (!uArch)
      return;
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();
    VectorType cdTy = dpas.getResultType();

    xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
        requiredBLayout;

    // Subgroup propagation additionally needs the consumer layout of the
    // result and the subgroup count of the enclosing kernel.
    int numSg = 0;
    if (layoutKind == xegpu::LayoutKind::Subgroup) {
      LayoutInfo consumerLayout = results[0]->getValue();
      if (!consumerLayout.isAssigned())
        return;
      consumerLayoutAttr =
          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
      auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
      if (failed(numSgOrErr)) {
        dpas.emitWarning(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      numSg = numSgOrErr.value();
    }
    auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
                                          consumerLayoutAttr, uArch, numSg);
    if (!layouts.has_value()) {
      dpas.emitWarning(
          "Failed to determine required layouts for DPAS operands.");
      return;
    }

    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;

    // Record the derived layouts on the op as new anchors.
    dpas.setLayoutAAttr(requiredALayout);
    dpas.setLayoutBAttr(requiredBLayout);
    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
    dpasALayout = LayoutInfo(requiredALayout);
    dpasBLayout = LayoutInfo(requiredBLayout);
    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
  }
  // Operand order: [0] = A, [1] = B, optional [2] = accumulator.
  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2)
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}
801
802/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
803void LayoutInfoPropagation::visitStoreNdOp(
804 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
805 ArrayRef<const LayoutInfoLattice *> results) {
806 LayoutInfo storeLayout;
807 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
808 if (hasParamsOfLayoutKind(anchorLayout)) {
809 storeLayout = LayoutInfo(anchorLayout);
810 } else {
811 const uArch *uArch = getUArch(getChipStr(store).value_or(""));
812 if (!uArch)
813 return;
814 const auto *uArchInstruction =
815 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
816 uArch->getInstruction(
817 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
818 VectorType dataTy = store.getValueType();
819 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
820 store.getValueType().getElementType());
821 if (!blockWHC)
822 store.emitWarning("No known block params found for the element type.");
823 auto [bWidth, bHeight, bCount] = blockWHC.value();
824 SmallVector<int> instData;
825 int instWidth = xegpu::getLargestDivisor(
826 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
827 if (instWidth == -1)
828 store.emitWarning(
829 "No suitable instruction multiple found for the given shape.");
830 if (dataTy.getRank() == 1)
831 instData = {instWidth};
832 else {
833 int instHeight = xegpu::getLargestDivisor(
834 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
835 if (instHeight == -1)
836 store.emitWarning(
837 "No suitable instruction multiple found for the given shape.");
838 instData = {instHeight, instWidth};
839 }
840
841 if (layoutKind == xegpu::LayoutKind::InstData)
842 storeLayout =
843 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
844 else if (layoutKind == xegpu::LayoutKind::Lane)
845 storeLayout =
846 getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
847 uArchInstruction->getPackedFormatBitSize());
848 else { // xegpu::LayoutKind::Subgroup
849 auto sgSize = uArch->getSubgroupSize();
850 auto numSgOrErr = getNumSg(store, sgSize);
851 if (failed(numSgOrErr)) {
852 store.emitWarning(
853 "Unable to determine the number of subgroups for the operation.");
854 return;
855 }
856 auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
857 instData, numSgOrErr.value());
858 if (sgLayouts.empty()) {
859 store.emitWarning(
860 "Unable to determine suitable subgroup layout for store value.");
861 return;
862 }
863 SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
864 SmallVector<int> sgData = {
865 static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
866 static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
867 storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
868 dataTy.getContext(),
869 DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
870 DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
871 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
872 /*lane_data =*/nullptr, /*order =*/nullptr));
873 }
874 store.setLayoutAttr(
875 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
876 }
877 // Propagate the layout to the value operand.
878 // Both operands should have the same layout
879 for (LayoutInfoLattice *operand : operands)
880 propagateIfChanged(operand, operand->meet(storeLayout));
881}
882
883/// Propagate the layout of the value to the tensor descriptor operand in
884/// LoadNdOp.
885void LayoutInfoPropagation::visitLoadNdOp(
886 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
887 ArrayRef<const LayoutInfoLattice *> results) {
888 LayoutInfo loadLayout;
889 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
890 if (hasParamsOfLayoutKind(anchorLayout)) {
891 loadLayout = LayoutInfo(anchorLayout);
892 } else {
893
894 LayoutInfo valueLayout = results[0]->getValue();
895 // Need the layout of the value to propagate to the tensor descriptor.
896 if (!valueLayout.isAssigned())
897 return;
898 loadLayout = valueLayout;
899 // LoadNdOp has the transpose effect. However, at the stage of this analysis
900 // this effect is not expected and should be abstracted away. Emit a
901 // warning.
902 if (auto transpose = load.getTranspose()) {
903 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
904 "LayoutInfoPropagation stage.");
905 loadLayout = valueLayout.transpose(transpose.value());
906 }
907 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
908 }
909 // Propagate the new layout to the tensor descriptor operand.
910 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
911}
912
913/// For vector::TransposeOp, the layout of the result is transposed and
914/// propagated to the operand.
915void LayoutInfoPropagation::visitTransposeOp(
916 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
917 ArrayRef<const LayoutInfoLattice *> results) {
918 // Need the layout of transpose result to propagate to the operands.
919 LayoutInfo resultLayout = results[0]->getValue();
920 if (!resultLayout.isAssigned())
921 return;
922 auto consumerLayoutAttr =
923 dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
924 auto srcLayoutAttr = xegpu::inferTransposeSourceLayout(
925 consumerLayoutAttr, transpose.getPermutation());
926 // Propagate the new layout to the vector operand.
927 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
928}
929
930/// For vector::BitCastOp, the lane_data of the source layout is changed based
931/// on the bit width of the source and result types.
932void LayoutInfoPropagation::visitVectorBitcastOp(
933 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
934 ArrayRef<const LayoutInfoLattice *> results) {
935 // Need the layout of bitcast result to propagate to the operands.
936 LayoutInfo resLayoutInfo = results[0]->getValue();
937 if (!resLayoutInfo.isAssigned())
938 return;
939
940 auto srcVecType = bitcast.getSourceVectorType();
941 auto resVecType = bitcast.getResultVectorType();
942
943 auto consumerLayoutAttr =
944 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
945 const uArch *uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
946 if (!uArch)
947 return;
948 auto requiredResLayoutAttr = setupBitCastResultLayout(
949 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
950
951 xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
952
953 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
954 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
955
956 // derive the source layout from the dominant layout and reduction dims
957 auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
958 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
959
960 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
961}
962
963void LayoutInfoPropagation::visitInsertStridedSliceOp(
964 vector::InsertStridedSliceOp insertStridedSlice,
965 ArrayRef<LayoutInfoLattice *> operands,
966 ArrayRef<const LayoutInfoLattice *> results) {
967 // The layout of the result must be present.
968 LayoutInfo resLayoutInfo = results[0]->getValue();
969 if (!resLayoutInfo.isAssigned())
970 return;
971
972 auto srcVecType = insertStridedSlice.getSourceVectorType();
973 auto resVecType = insertStridedSlice.getDestVectorType();
974
975 auto consumerLayoutAttr =
976 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
977 const uArch *uArch =
978 getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
979 if (!uArch)
980 return;
981
982 auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
983 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
984 xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
985 requiredResLayoutAttr);
986
988 requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
989 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
990 propagateIfChanged(operands[1],
991 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
992}
993
994/// Propagate the layout of the result to the tensor descriptor, mask and offset
995/// operands in LoadGatherOp.
996void LayoutInfoPropagation::visitLoadGatherOp(
997 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
998 ArrayRef<const LayoutInfoLattice *> results) {
999 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1000 xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
1001 const uArch *uArch = getUArch(getChipStr(load).value_or(""));
1002 if (!uArch)
1003 return;
1004 auto subgroupSize = uArch->getSubgroupSize();
1005 VectorType resVecTy = load.getValueType();
1006 int chunkSize = load.getChunkSize().value_or(1);
1007
1008 LayoutInfo resLayoutInfo = results[0]->getValue();
1009 if (!resLayoutInfo.isAssigned())
1010 return;
1011 auto consumerLayoutAttr =
1012 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1013
1014 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1015 requiredAnchorLayoutAttr = anchorLayoutAttr;
1016 } else {
1017 if (!resVecTy) {
1018 load.emitWarning("Not propagating, non-vector payload supplied.");
1019 return;
1020 }
1021 requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
1022 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1023 load.setLayoutAttr(requiredAnchorLayoutAttr);
1024 }
1025
1026 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1027 // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
1028 // layout for mask.
1029 if (chunkSize > 1) {
1030 if (layoutKind == xegpu::LayoutKind::InstData)
1031 maskLayoutAttr =
1032 xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
1033 else if (layoutKind == xegpu::LayoutKind::Lane)
1034 maskLayoutAttr =
1035 xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
1036 else
1037 assert(false &&
1038 "chunked StoreScatterOp should not be used at workgroup level");
1039 }
1040
1041 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1042 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1043
1044 // Propagate the new layout to the tensor descriptor operand.
1045 if (isa<xegpu::TensorDescType>(load.getSourceType()))
1046 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1047 // Propagate the new layout to the mask and optional offset operand.
1048 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1049 if (load.getOffsets())
1050 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1051}
1052
1053/// Propagate the layout of the descriptor to the vector offset operand in
1054/// CreateDescOp.
1055void LayoutInfoPropagation::visitCreateDescOp(
1056 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
1057 ArrayRef<const LayoutInfoLattice *> results) {
1058 LayoutInfo descLayout = results[0]->getValue();
1059 // Need the layout of the descriptor to propagate to the operands.
1060 if (!descLayout.isAssigned())
1061 return;
1062 const uArch *uArch = getUArch(getChipStr(createDesc).value_or(""));
1063 if (!uArch)
1064 return;
1065 // For offset operand propagate 1D default layout.
1066 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
1067 uArch->getSubgroupSize());
1068 propagateIfChanged(operands[1], operands[1]->meet(layout));
1069}
1070
1071/// Set the layout for the value, tensor descriptor, offset and mask operands in
1072/// the StoreScatterOp.
1073void LayoutInfoPropagation::visitStoreScatterOp(
1074 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1075 ArrayRef<const LayoutInfoLattice *> results) {
1076
1077 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1078 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1079 const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
1080 if (!uArch)
1081 return;
1082 auto subgroupSize = uArch->getSubgroupSize();
1083 VectorType srcVecTy = storeScatter.getValueType();
1084 int chunkSize = storeScatter.getChunkSize().value_or(1);
1085
1086 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1087 requiredAnchorLayoutAttr = anchorLayoutAttr;
1088 } else {
1089 if (!srcVecTy) {
1090 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
1091 return;
1092 }
1093 requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
1094 layoutKind, srcVecTy, chunkSize, uArch);
1095 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1096 }
1097
1098 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1099 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1100 // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
1101 // layout for mask.
1102 if (chunkSize > 1) {
1103 if (layoutKind == xegpu::LayoutKind::InstData)
1104 maskLayoutAttr =
1105 xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
1106 else if (layoutKind == xegpu::LayoutKind::Lane)
1107 maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
1108 {subgroupSize}, {1});
1109 else
1110 assert(false &&
1111 "chunked StoreScatterOp should not be used at workgroup level");
1112 }
1113
1114 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1115
1116 // Propagate the payload operand layout
1117 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1118 // Propagate the destination (if tdesc) operand layout
1119 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1120 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1121 // Propagate the new layout to the mask and optional offset operand.
1122 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1123 if (storeScatter.getOffsets())
1124 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
1125}
1126
1127void LayoutInfoPropagation::visitLoadMatrixOp(
1128 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1129 ArrayRef<const LayoutInfoLattice *> results) {
1130
1131 LayoutInfo resLayoutInfo = results[0]->getValue();
1132 if (!resLayoutInfo.isAssigned())
1133 return;
1134
1135 auto consumerLayoutAttr =
1136 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1137
1138 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1139
1140 // only need to set anchor layout, no need to porpagate to memdesc and
1141 // offset
1142 if (!hasParamsOfLayoutKind(anchorLayout)) {
1143 VectorType resVecTy =
1144 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1145 assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
1146 const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
1147 if (!uArch)
1148 return;
1149 auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
1150 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1151 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
1152 }
1153}
1154
1155// Store matrix is a flavor of scattered store for 2D shapes.
1156void LayoutInfoPropagation::visitStoreMatrixOp(
1157 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1158 ArrayRef<const LayoutInfoLattice *> results) {
1159 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1160 LayoutInfo layout;
1161 if (hasParamsOfLayoutKind(anchorLayout)) {
1162 layout = LayoutInfo(anchorLayout);
1163 } else {
1164 VectorType srcVecTy =
1165 llvm::cast<VectorType>(storeMatrix.getData().getType());
1166 assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
1167 const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
1168 if (!uArch)
1169 return;
1170 auto requiredAnchorLayoutAttr =
1171 xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
1172 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1173 layout = LayoutInfo(requiredAnchorLayoutAttr);
1174 }
1175
1176 propagateIfChanged(operands[0], operands[0]->meet(layout));
1177}
1178
1179namespace {
1180//===----------------------------------------------------------------------===//
1181// RunLayoutInfoPropagation
1182//===----------------------------------------------------------------------===//
1183
/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  /// Loads the baseline analyses plus LayoutInfoPropagation into the solver
  /// and runs it over \p op immediately; query results afterwards through
  /// getLayoutInfo().
  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
                           unsigned indexBitWidth)
      : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
    // A solver failure is deliberately ignored: values that were not analyzed
    // simply report an unassigned layout.
    (void)solver.initializeAndRun(op);
  }

  /// Returns the layout computed for \p val, or a default-constructed
  /// LayoutInfo when the solver recorded no state for it.
  LayoutInfo getLayoutInfo(Value val);

  /// Dumps the per-function analysis results to \p os (debugging aid).
  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;        // Owns the analysis state.
  const Operation *target;      // Root op the analysis was run on.
};
1206} // namespace
1207
1208LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1209 auto *state = solver.lookupState<LayoutInfoLattice>(val);
1210 if (!state)
1211 return {};
1212 return state->getValue();
1213}
1214
1215// Print the analysis result for debugging purposes.
1216void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1217 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1218 os << "function: " << funcOp.getName() << ":\n";
1219 // Function arguments
1220 for (BlockArgument arg : funcOp.getArguments()) {
1221 LayoutInfo layout = getLayoutInfo(arg);
1222 os << "argument: " << arg << "\n";
1223 os << "layout : ";
1224 layout.print(os);
1225 os << "\n";
1226 }
1227 // Function ops
1228 funcOp.walk([&](Operation *op) {
1229 // Skip ops that do not have results
1230 if (op->getResults().empty())
1231 return;
1232 os << "op : ";
1233 // For control-flow ops, print the op name only.
1234 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1235 os << op->getName();
1236 else
1237 op->print(os);
1238 os << "\n";
1239 // Print the layout for each result.
1240 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1241 LayoutInfo layout = getLayoutInfo(r);
1242 os << "layout for result #" << i << ": ";
1243 layout.print(os);
1244 os << "\n";
1245 }
1246 });
1247 };
1248
1249 SmallVector<FunctionOpInterface> funcOps;
1250 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1251 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1252 funcOps.push_back(funcOp);
1253
1254 // Collect all GpuFuncOps in the module.
1255 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1256 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1257 funcOps.push_back(gpuFuncOp);
1258 }
1259 }
1260 // Print the analysis result for each function.
1261 for (FunctionOpInterface funcOp : funcOps)
1262 printFunctionResult(funcOp);
1263}
1264
1265namespace {
1266
1267//===----------------------------------------------------------------------===//
1268// ResolveLayoutConflicts
1269//===----------------------------------------------------------------------===//
1270
1271/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
1272/// function tries to find the defining CreateNdDescOp recursively accross
1273/// control-flow boundaries.
1274static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1275 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1276 auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
1277 if (definingOp)
1278 return definingOp;
1279 // If tdescValue is an argument, try to get the tied init value from the
1280 // parent loop-like op.
1281 if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1282 auto *parentOp = arg.getOwner()->getParentOp();
1283 if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1284 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1285 if (tiedInit)
1286 return getDefiningCreateNdDescOp(tiedInit->get());
1287 }
1288 }
1289 // If not found, return null.
1290 return nullptr;
1291}
1292
/// Rewrites the IR so that every use point whose producer layout disagrees
/// with the consumer's expected layout gets an explicit conversion (see
/// run()).
struct ResolveLayoutConflicts {
  ResolveLayoutConflicts(Operation *parentOp)
      : parentOp(parentOp), builder(parentOp->getContext()) {}
  /// Walks all ops nested under parentOp and resolves conflicting operands.
  /// Returns failure if any conflict could not be resolved.
  LogicalResult run();

private:
  Operation *parentOp; // Root op whose nested regions are scanned.
  OpBuilder builder;   // Creates the replacement / conversion ops.
  /// Resolves a layout mismatch on a tensor-descriptor operand of an
  /// anchor-layout op.
  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
  /// Resolves a layout mismatch on a vector operand.
  LogicalResult resolveVectorConsumer(OpOperand &operand);
};
1304
1305} // namespace
1306
1307LogicalResult ResolveLayoutConflicts::run() {
1308 // Scan all operations in the parent op and resolve layout conflicts at
1309 // tensor descriptor and vector use points.
1310 auto r = parentOp->walk([&](Operation *op) -> WalkResult {
1311 for (OpOperand &operand : op->getOpOperands()) {
1312 // Handle conflicts in tensor descriptor operands.
1313 Type operandType = operand.get().getType();
1314 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1315 isa<xegpu::TensorDescType>(operandType)) {
1316 auto res = resolveTensorDescConsumer(operand);
1317 if (failed(res)) {
1318 DBGS() << "Failed to resolve tensor descriptor consumer: " << *op
1319 << "\n";
1320 return WalkResult::interrupt();
1321 }
1322 }
1323 // Handle conflicts in vector operands.
1324 if (isa<VectorType>(operandType)) {
1325 auto res = resolveVectorConsumer(operand);
1326 if (failed(res)) {
1327 DBGS() << "Failed to resolve vector consumer: " << *op << "\n";
1328 return WalkResult::interrupt();
1329 }
1330 }
1331 }
1332 return WalkResult::advance();
1333 });
1334
1335 return r.wasInterrupted() ? failure() : success();
1336}
1337
1338LogicalResult
1339ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1340 Value vectorValue = operand.get();
1341 Operation *consumerOp = operand.getOwner();
1342 // Get the current layout of the vector value.
1343 auto producerLayout = xegpu::getDistributeLayoutAttr(vectorValue);
1344 if (!producerLayout) {
1345 if (auto vectorTy = dyn_cast<VectorType>(vectorValue.getType());
1346 vectorTy && vectorTy.getRank() > 1)
1347 consumerOp->emitWarning("Expected layout for non-1D vectors.");
1348 return success(); // uniform non-tensor-data vector does not require layout
1349 }
1350 // Get the consumer expected layout at this operand.
1351 auto consumerLayout = xegpu::getConsumerLayoutAt(operand);
1352 if (!consumerLayout)
1353 return consumerOp->emitError(
1354 "No consumer layout found for vector operand.");
1355
1356 // If layouts are same, no conflict exists, return success.
1357 if (consumerLayout.isEqualTo(producerLayout))
1358 return success();
1359
1360 // Insert a convert_layout op to resolve the conflict.
1361 builder.setInsertionPointAfterValue(vectorValue);
1362 auto convertOp = xegpu::ConvertLayoutOp::create(
1363 builder, consumerOp->getLoc(), vectorValue.getType(), vectorValue,
1364 producerLayout, consumerLayout);
1365
1366 // Update the operand to use the converted value.
1367 operand.set(convertOp.getResult());
1368 return success();
1369}
1370
1371LogicalResult
1372ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1373 Operation *consumerOp = operand.getOwner();
1374 Value tdescValue = operand.get();
1375 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1376 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
1377 assert(anchorOp && currTDescType &&
1378 "Expected anchor layout op and tensor descriptor consumer.");
1379 // TODO: Scattered tensor desc is not supported for now.
1380 if (currTDescType.isScattered()) {
1381 DBGS() << "Scattered tensor descriptor not supported: " << tdescValue
1382 << "\n";
1383 return failure();
1384 }
1385 Attribute currLayout = currTDescType.getLayout();
1386 Attribute expectedLayout = anchorOp.getAnchorLayout();
1387 // A conflict exists in tensor descriptor operand if tensor descriptor's
1388 // layout is different from the anchor layout expected by the consumer.
1389 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1390 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1391 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1392 if (!conflictingCreateNdOp) {
1393 DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
1394 << tdescValue << "\n";
1395 return failure();
1396 }
1397 // Duplicate the CreateNdDescOp with the expected layout.
1398 builder.setInsertionPointAfter(conflictingCreateNdOp);
1399 auto newTensorDescType = xegpu::TensorDescType::get(
1400 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1401 currTDescType.getElementType(), currTDescType.getEncoding(),
1402 expectedLayout);
1403 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1404 builder, consumerOp->getLoc(), newTensorDescType,
1405 conflictingCreateNdOp->getOperands(),
1406 conflictingCreateNdOp->getAttrs());
1407 // Replace the tensor descriptor operand in the consumer op with the new
1408 // tensor descriptor.
1409 consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
1410 }
1411 return success();
1412}
1413
// Callback used to look up the analysis-assigned layout of a value.
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
/// the result type is a tensor descriptor type, the type is updated with the
/// layout attribute. The users of the result are also updated with the layout
/// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are already handled by the
  // updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor types.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has no layout but has users, emit a warning and continue.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // If the result is a tensor descriptor type, update the tensor desc type
    // with layout.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // If the result is a vector type, add a temporary layout attribute to the
    // op.
    // NOTE(review): the statement that attaches the layout to vector results
    // (likely a call to xegpu::setTemporaryLayout, used elsewhere in this
    // file) appears to be missing from this copy — confirm and restore
    // against the original source.
  }
  return success();
}
1454
/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments,
/// thier types must be updated. Also region op can have results that may not
/// have any users (e.g. A and B tiles). They are not assigned a layout by
/// layout analysis because they have no users. However inside the region op
/// corresponding block arguments for these results do have layouts.
/// Therefore, in this case we still need to update the result types with the
/// layout attribute. This function function updates the internal block
/// arguments and the result types of the region op with the assigned layouts.
/// clang-format off
/// Example: scf.for ... iter_args(...) -> (out types) {
/// ^bb0(block types):
///     ...
///     scf.yield ... : (yield types)
/// }
/// clang-format on
/// In this example, at scf.yield, control-flow can transfer to two successor
/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
/// itself (yield the results). So we update both the block arguments of the
/// successor region (i.e. block types) and the result types of the scf.for op
/// (i.e. out types). Note that yield types are updated by respective
/// producers inside bb0.
static LogicalResult
// NOTE(review): the opening line of this signature — presumably
// "updateControlFlowOps(mlir::OpBuilder &builder," as recorded in this
// file's declaration index — is missing from this copy; confirm and restore
// against the original source.
    mlir::RegionBranchTerminatorOpInterface terminator,
    GetLayoutFnTy getLayoutOfValue) {
  // Only process if the terminator is inside a region branch op.
  auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
  if (!branchOp)
    return success();

  // NOTE(review): the declaration of `mapping` (the successor-operand to
  // successor-inputs map filled by the call below) appears to be missing
  // from this copy — confirm and restore against the original source.
  branchOp.getSuccessorOperandInputMapping(mapping,
                                           RegionBranchPoint(terminator));
  for (const auto &[successorOperand, successorInputs] : mapping) {
    for (Value successorInput : successorInputs) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand->get());

      // If either of the layouts is not assigned, we cannot proceed.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand->get() << "\n");
        return failure();
      }
      // We expect the layouts to match.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      // Get tensor descriptor type with the layout.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // If the type is a vector type and this region argument is an OpResult,
      // set the layout attribute on the OpResult.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
1532
1533/// Update the function arguments and results with the layouts.
1534static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1535 mlir::FunctionOpInterface funcOp,
1536 GetLayoutFnTy getLayoutOfValue) {
1537 // Only process functions whose type is a standard MLIR FunctionType.
1538 // Functions using a different type representation (e.g. llvm.func with
1539 // LLVMFunctionType) are not targets for XeGPU layout propagation, and
1540 // calling setType(FunctionType{}) on them would corrupt their type.
1541 if (!isa<FunctionType>(funcOp.getFunctionType()))
1542 return success();
1543 SmallVector<Type> newArgTypes;
1544 // Update the function arguments.
1545 for (BlockArgument arg : funcOp.getArguments()) {
1546 Type argType = arg.getType();
1547 newArgTypes.push_back(argType);
1548 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1549 continue;
1550 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1551 if (!layout) {
1552 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1553 << " but got none.\n");
1554 return failure();
1555 }
1556 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1557 auto newTdescTy = xegpu::TensorDescType::get(
1558 tensorDescTy.getContext(), tensorDescTy.getShape(),
1559 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1560 arg.setType(newTdescTy);
1561 newArgTypes.back() = newTdescTy;
1562 }
1563 }
1564 // Update the function type with the new argument types.
1565 // NOTE: We assume that function results are not expected to have layouts.
1566 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1567 funcOp.getResultTypes()));
1568 return success();
1569}
1570
1571namespace {
/// Pass wrapper that runs layout propagation followed by conflict resolution
/// over the operation it is scheduled on. The pass options (layout kind,
/// index bit width, print-only) come from the tablegen-generated base class.
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(std::move(options)) {}
  void runOnOperation() override;
};
1580
1581} // namespace
1582
// NOTE(review): the opening line of this definition's signature — presumably
// "LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation
// *target," given how it is called from runOnOperation() below — is missing
// from this copy; confirm and restore against the original source.
                                       LayoutKind layoutKind,
                                       unsigned indexBitWidth, bool printOnly) {
  // Run the layout propagation dataflow analysis over the whole target op.
  RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
  // Print the analysis result and exit. (for debugging purposes)
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return success();
  }
  // Helper to convert LayoutInfo to xegpu::LayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    if (auto opResult = dyn_cast<OpResult>(val)) {
      // Prefer the anchor layout recorded on the defining op, if any.
      Operation *defOp = opResult.getDefiningOp();
      if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
        auto anchorLayout = anchorOp.getAnchorLayout();
        if (anchorLayout != nullptr)
          return anchorLayout;
      }
      // Otherwise fall back to a temporary layout recorded during analysis.
      xegpu::DistributeLayoutAttr requiredResLayoutAttr =
          xegpu::getTemporaryLayout(opResult);
      if (requiredResLayoutAttr != nullptr)
        return requiredResLayoutAttr;
    }
    // Last resort: the lattice value itself, narrowed to the concrete kind.
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);

    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  Operation *op = target;
  // Visit the ops of each block in reverse order and write the computed
  // layouts back into the IR.
  auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      // NOTE(review): the TypeSwitch header lines (e.g.
      // "TypeSwitch<Operation *>(&op)") that the .Case chain below hangs off
      // are missing from this copy — confirm and restore against the
      // original source.
          .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
            r = updateControlFlowOps(builder, branchTermOp,
                                     getXeGPULayoutForValue);
          })
          .Case([&](mlir::FunctionOpInterface funcOp) {
            r = updateFunctionOpInterface(builder, funcOp,
                                          getXeGPULayoutForValue);
          })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();

  return success();
}
1647
// NOTE(review): the signature line of this definition — presumably
// "LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {" given
// how it is called from runOnOperation() below — is missing from this copy;
// confirm and restore against the original source.
  // Run the conflict-resolution rewriter over the target op.
  ResolveLayoutConflicts resolver(target);
  return resolver.run();
}
1652
1653void XeGPUPropagateLayoutPass::runOnOperation() {
1654 xegpu::LayoutKind layoutKind;
1655 if (this->layoutKind == "lane") {
1656 layoutKind = xegpu::LayoutKind::Lane;
1657 } else if (this->layoutKind == "inst") {
1658 layoutKind = xegpu::LayoutKind::InstData;
1659 } else if (this->layoutKind == "subgroup") {
1660 layoutKind = xegpu::LayoutKind::Subgroup;
1661 } else {
1662 getOperation()->emitError("Unsupported layout kind option: " +
1663 this->layoutKind);
1664 signalPassFailure();
1665 return;
1666 }
1667 OpBuilder builder(&getContext());
1668 if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
1669 this->indexBitWidth, this->printOnly))) {
1670 signalPassFailure();
1671 return;
1672 }
1673 // Resolve layout conflicts if any.
1674 if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
1675 signalPassFailure();
1676 return;
1677 }
1678}
return success()
#define DBGS()
Definition Hoisting.cpp:32
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
lhs
b getContext())
auto load
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:412
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
result_range getResults()
Definition Operation.h:444
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents a lattice holding a specific value of type ValueT.
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, const uArch::uArch *uArch)
Sets up layout for reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch, int numSg)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:163