MLIR 23.0.0git
XeGPUPropagateLayout.cpp
Go to the documentation of this file.
1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/Visitors.h"
31#include "mlir/Support/LLVM.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51
52using namespace mlir;
53using namespace mlir::dataflow;
54
55namespace {
56
57//===----------------------------------------------------------------------===//
58// LayoutInfo
59//===----------------------------------------------------------------------===//
60
61/// Helper class for tracking the analysis state of an mlir value. For layout
62/// propagation, the analysis state is simply the distribution layout of
63/// each value. The distribution layout information is encapsulated using
64/// xegpu::DistributeLayoutAttr class which can hold information about any type
65/// of distribution layout that XeGPU dialect supports. Purpose of this analysis
66/// to propagate some unique distribution layout for each value in the program
67/// starting from a set of anchor operations (like DPAS, StoreNd, etc.). Note
68/// that analysis will reach a fixed point when all values are reached some
69/// layout and, analysis does not try to modify any already assigned layouts.
70///
71/// Given this, LayoutInfo satisifies the following properties:
72/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
73/// assigned`.
74/// 2) Two LayoutInfo values are equal if they are both assigned or
75/// both not assigned. The concrete value of assigned state does not matter.
76/// 3) The meet operator works as follows:
77/// - If current state is assigned, return the current state. (already
78/// a unique layout is assigned. don't change it)
79/// - Otherwise, return the other state.
80
/// Lattice value tracked per mlir::Value: either "no layout assigned yet"
/// (storage == nullptr) or "some layout assigned". Equality and meet only
/// look at assigned-ness, never at the concrete layout, so the analysis
/// reaches a fixed point as soon as every value has *a* layout.
struct LayoutInfo {
private:
  /// The assigned layout, or nullptr while unassigned.
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  /// Keep lhs if it is already assigned; otherwise adopt rhs.
  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  /// Unused (backward analysis) — hard-fails if called.
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const { return storage != nullptr; }

  /// Return a copy of this layout with dimensions permuted; see definition.
  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  // Accessors below return the corresponding "effective" layout component
  // as ints, or an empty vector when unassigned (or, for getOrder, when no
  // order attribute is present).
  SmallVector<int> getLaneLayout() const;

  SmallVector<int> getLaneData() const;

  SmallVector<int> getInstData() const;

  SmallVector<int> getSgLayout() const;

  SmallVector<int> getSgData() const;

  SmallVector<int> getOrder() const;

  /// True iff an assigned layout is a SliceAttr.
  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  /// Rank of the assigned layout, or -1 when unassigned.
  int64_t getRank() const {
    if (!isAssigned())
      return -1;
    return storage.getRank();
  }

  // Raw access to the underlying attribute (nullptr when unassigned).
  Attribute get() { return storage; }
  void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
};
132
133SmallVector<int> LayoutInfo::getLaneLayout() const {
134 if (!isAssigned())
135 return {};
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](int64_t val) { return static_cast<int>(val); });
138}
139
140SmallVector<int> LayoutInfo::getLaneData() const {
141 if (!isAssigned())
142 return {};
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](int64_t val) { return static_cast<int>(val); });
145}
146
147SmallVector<int> LayoutInfo::getInstData() const {
148 if (!isAssigned())
149 return {};
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](int64_t val) { return static_cast<int>(val); });
152}
153
154SmallVector<int> LayoutInfo::getSgLayout() const {
155 if (!isAssigned())
156 return {};
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](int64_t val) { return static_cast<int>(val); });
159}
160
161SmallVector<int> LayoutInfo::getSgData() const {
162 if (!isAssigned())
163 return {};
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](int64_t val) { return static_cast<int>(val); });
166}
167
168SmallVector<int> LayoutInfo::getOrder() const {
169 if (!isAssigned() || !storage.getOrder())
170 return {};
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](int64_t val) { return static_cast<int>(val); });
173}
174
175void LayoutInfo::print(raw_ostream &os) const {
176 if (isAssigned()) {
177 os << storage;
178 } else {
179 os << "Not assigned.";
180 }
181}
182
183LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
184 if (!lhs.isAssigned())
185 return rhs;
186 return lhs;
187}
188
/// Since this is a backward analysis, the lattice join operator is never
/// exercised — only meet is. Reaching this function indicates a framework
/// misuse, hence the hard failure.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
193
194/// Construct a new layout with the transposed inst_data or lane_layout,
195/// lane_data.
196LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
197 if (!isAssigned())
198 return {};
199 // Check if the permutation is valid.
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
204 });
205
206 if (!withinRange || hasDuplicates) {
207 assert(false && "Invalid permutation for transpose.");
208 return {};
209 }
210
211 SmallVector<int32_t> laneLayout;
212 SmallVector<int32_t> laneData;
213 SmallVector<int32_t> instData;
214 SmallVector<int32_t> sgLayout;
217
218 for (int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
221 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
222 }
223 if (getInstData().size())
224 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
227 sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
228 }
229 if (getOrder().size()) {
230 order.push_back(static_cast<int32_t>(getOrder()[idx]));
231 }
232 }
233 auto orderAttr = order.size()
234 ? DenseI32ArrayAttr::get(storage.getContext(), order)
235 : nullptr;
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
238 layoutAttr =
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
245 DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
246 DenseI32ArrayAttr::get(storage.getContext(), sgData),
247 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
248 /*lane_data =*/nullptr, orderAttr);
249 return LayoutInfo(layoutAttr);
250}
251
//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value. Inherits all constructors
/// from the generic dataflow Lattice; no extra state is added.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
261
262/// Helper Functions to get default layouts. A `default layout` is a layout that
263/// is assigned to a value when the layout is not fixed by some anchor operation
264/// (like DPAS).
265
266/// Helper Function to get the default layout for uniform values like constants.
267/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
268/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
269static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
270 unsigned rank,
271 const xegpu::uArch::uArch *uArch) {
272 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
273 if (rank == 1) {
274 return LayoutInfo(
275 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
276 }
277 return LayoutInfo(
278 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
279}
280
281static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
282 unsigned rank, int subgroupSize) {
283 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
284 if (rank == 1) {
285 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
286 }
287 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
288}
289
290/// Helper to get the default layout for 2D block operations.
291template <typename Ty>
292static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
294 unsigned packingSize) {
295 // Expecting a 1D or 2D vector.
296 assert((ty.getRank() == 1 || ty.getRank() == 2) &&
297 "Expected 1D or 2D vector.");
298 // Expecting int or float element type.
299 assert(ty.getElementType().isIntOrFloat() &&
300 "Expected int or float element type.");
301 // If the rank is 1, then return default layout for 1D vector.
302 if (ty.getRank() == 1)
303 return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
304 // Packing factor is determined by the element type bitwidth.
305 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
306 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307 return LayoutInfo(xegpu::LayoutAttr::get(
308 ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
309}
310
311//===----------------------------------------------------------------------===//
312// LayoutInfoPropagation
313//===----------------------------------------------------------------------===//
314
315/// Backward data flow analysis to propagate the lane_layout and lane_data of
316/// each value in the program. Currently, the layouts for operands DPAS,
317/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
318/// this analysis is to propagate those known layouts to all their producers and
319/// (other) consumers.
320class LayoutInfoPropagation
321 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
322private:
323 xegpu::LayoutKind layoutKind;
324 unsigned indexBitWidth;
325 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
327
328 void visitStoreNdOp(xegpu::StoreNdOp store,
331
332 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
335
336 void visitLoadNdOp(xegpu::LoadNdOp load,
339
340 void visitLoadGatherOp(xegpu::LoadGatherOp load,
343
344 void visitTransposeOp(vector::TransposeOp transpose,
347
348 void visitVectorBitcastOp(vector::BitCastOp bitcast,
351
352 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
355
356 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
359
360 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
363
364 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
367
368 void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
371 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
374 void
375 visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
378
379 void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
382
383 void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
386
387 void visitLoadGatherOp(xegpu::LoadMatrixOp load,
390
391 void visitStoreScatterOp(xegpu::StoreMatrixOp store,
394
395 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
396
397public:
398 LayoutInfoPropagation(DataFlowSolver &solver,
399 SymbolTableCollection &symbolTable,
400 xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
401 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
402 layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
404
405 LogicalResult
406 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
407 ArrayRef<const LayoutInfoLattice *> results) override;
408
409 void visitBranchOperand(OpOperand &operand) override {};
410
411 void visitCallOperand(OpOperand &operand) override {};
412
413 void
414 visitNonControlFlowArguments(RegionSuccessor &successor,
415 ArrayRef<BlockArgument> arguments) override {};
416
417 void visitExternalCall(CallOpInterface call,
419 ArrayRef<const LayoutInfoLattice *> results) override {
420 };
421
422 void setToExitState(LayoutInfoLattice *lattice) override {
423 (void)lattice->meet(LayoutInfo());
424 }
425};
426} // namespace
427
428LogicalResult LayoutInfoPropagation::visitOperation(
429 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
430 ArrayRef<const LayoutInfoLattice *> results) {
432 .Case(
433 [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
434 .Case([&](xegpu::StoreNdOp storeNdOp) {
435 visitStoreNdOp(storeNdOp, operands, results);
436 })
437 .Case([&](xegpu::StoreScatterOp storeScatterOp) {
438 visitStoreScatterOp(storeScatterOp, operands, results);
439 })
440 .Case([&](xegpu::LoadNdOp loadNdOp) {
441 visitLoadNdOp(loadNdOp, operands, results);
442 })
443 .Case([&](xegpu::LoadGatherOp loadGatherOp) {
444 visitLoadGatherOp(loadGatherOp, operands, results);
445 })
446 .Case([&](xegpu::CreateDescOp createDescOp) {
447 visitCreateDescOp(createDescOp, operands, results);
448 })
449 .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
450 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
451 })
452 .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
453 visitPrefetchNdOp(prefetchNdOp, operands, results);
454 })
455 .Case([&](vector::TransposeOp transposeOp) {
456 visitTransposeOp(transposeOp, operands, results);
457 })
458 .Case([&](vector::BitCastOp bitcastOp) {
459 visitVectorBitcastOp(bitcastOp, operands, results);
460 })
461 .Case([&](vector::MultiDimReductionOp reductionOp) {
462 visitVectorMultiReductionOp(reductionOp, operands, results);
463 })
464 .Case([&](vector::BroadcastOp broadcastOp) {
465 visitVectorBroadCastOp(broadcastOp, operands, results);
466 })
467 .Case([&](vector::ShapeCastOp shapeCastOp) {
468 visitShapeCastOp(shapeCastOp, operands, results);
469 })
470 .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
471 visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
472 })
473 .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
474 visitLoadMatrixOp(loadMatrixOp, operands, results);
475 })
476 .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
477 visitStoreMatrixOp(storeMatrixOp, operands, results);
478 })
479 // All other ops.
480 .Default([&](Operation *op) {
481 for (const LayoutInfoLattice *resultInfo : results) {
482 if (!resultInfo->getValue().isAssigned())
483 continue;
484 for (auto [operandInfo, operand] :
485 llvm::zip(operands, op->getOpOperands())) {
486 // If the operand type is not a vector or tensor descriptor, skip
487 // it.
488 if (!isa<xegpu::TensorDescType, VectorType>(
489 operand.get().getType()))
490 continue;
491 // Propagate the result layout to the operand.
492 meet(operandInfo, *resultInfo);
493 }
494 }
495 });
496
497 return success();
498}
499
500bool LayoutInfoPropagation::hasParamsOfLayoutKind(
501 xegpu::DistributeLayoutAttr anchorLayout) {
502 if (anchorLayout == nullptr) {
503 return false;
504 }
505 if (layoutKind == xegpu::LayoutKind::InstData) {
506 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
507 }
508 if (layoutKind == xegpu::LayoutKind::Lane) {
509 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
510 anchorLayout.getEffectiveLaneDataAsInt().empty());
511 }
512 if (layoutKind == xegpu::LayoutKind::Subgroup) {
513 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
514 anchorLayout.getEffectiveSgDataAsInt().empty());
515 }
516 return false;
517}
518
519// This function returns all layouts for the given sgCount, whose sgData:
520// 1. Evenly divides the wgShape.
521// 2. Is a multiple of instData.
522// Example:
523// wgShape = [128, 64], instData = [8, 16], sgCount = 32
524// Returns layouts:
525// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
527 ArrayRef<int> instData,
528 int64_t sgCount) {
530 for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
531 if (sgCount % sgLayout0)
532 continue;
533 int sgLayout1 = sgCount / sgLayout0;
534 int sgData0 = wgShape[0] / sgLayout0;
535 int sgData1 = wgShape[1] / sgLayout1;
536 if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
537 (sgData0 % instData[0] || sgData1 % instData[1]))
538 continue;
539 candidates.emplace_back(sgLayout0, sgLayout1);
540 }
541 // Sort primarily by how balanced they are
542 // (i.e., minimize the absolute difference between the two dimensions), and
543 // secondarily by the first dimension in ascending order.
544 llvm::sort(candidates, [](const std::pair<int, int> &lhs,
545 const std::pair<int, int> &rhs) {
546 int diffLhs = std::abs(lhs.first - lhs.second);
547 int diffRhs = std::abs(rhs.first - rhs.second);
548 if (diffLhs != diffRhs)
549 return diffLhs < diffRhs;
550 return lhs.first < rhs.first;
551 });
552 return candidates;
553}
554
555FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
556 // Oblivious to workitem layout, the total count matters.
557 auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
558 if (!gpuFunc)
559 return failure();
560 auto knownBlockSize = gpuFunc.getKnownBlockSize();
561 if (!knownBlockSize.has_value())
562 return failure();
563 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
564 return flatBlockSize / sgSize;
565}
566
/// Assign a layout to the tensor descriptor operand of PrefetchNdOp: use the
/// op's anchor layout when it already has the propagated kind's parameters,
/// otherwise derive one from the target uArch's 2D-block-prefetch
/// instruction parameters and record it back on the op.
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo prefetchLayout;
  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    prefetchLayout = LayoutInfo(anchorLayout);
  } else {
    // Here we assign the default layout to the tensor descriptor operand of
    // prefetch.
    auto tdescTy = prefetch.getTensorDescType();

    // No known target: silently skip; nothing can be decided.
    const uArch *uArch = getUArch(getChipStr(prefetch).value_or(""));
    if (!uArch)
      return;
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));

    auto blockWHC =
        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
    if (!blockWHC)
      prefetch.emitWarning("No known block params found for the element type.");
    // NOTE(review): value() is called even when blockWHC is empty (only a
    // warning is emitted above) — confirm whether an early return is intended.
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    // Largest divisor of the innermost dim that the HW block width allows;
    // -1 means no valid multiple exists (warned, not aborted).
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (tdescTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        prefetch.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }

    // Pick the layout flavor matching the propagated kind: raw inst_data, or
    // a lane-distribution layout derived from the block-IO packing rules.
    if (layoutKind == xegpu::LayoutKind::InstData)
      prefetchLayout =
          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
    else
      prefetchLayout = getSIMTLayoutInfoBlockIO(
          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());

    // Record the decision as the op's anchor layout for later phases.
    prefetch.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
  }
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
623
/// Fix the result layout required by a multi-dim reduction, record it as a
/// temporary layout, and back-propagate the inferred source layout to the
/// vector operand and the result layout to the accumulator.
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;

  VectorType sourceTy = reduction.getSourceVectorType();
  SmallVector<int64_t> reductionDims(reduction.getReductionDims());

  // Layout setup is uArch-specific; skip if the target chip is unknown.
  const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
  if (!uArch)
    return;
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

  // The result layout represents the layout requirements of the operation.
  // it is recorded to anchor layout or temporary layout.
  // it must be honored for current op and may conflict with the layout
  // propagated from consumer op, the conflict is resolved in later phase by
  // converting the required result layout to the consumer layout
  auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
      layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);

  xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);

  // derive the source layout from the dominant layout and reduction dims
  auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
      requiredResLayoutAttr, reductionDims);

  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
  // Accumulator should have the same layout as the result.
  propagateIfChanged(operands[1],
                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
}
661
662void LayoutInfoPropagation::visitVectorBroadCastOp(
663 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
664 ArrayRef<const LayoutInfoLattice *> results) {
665 // The layout of the result must be present.
666 LayoutInfo resLayoutInfo = results[0]->getValue();
667 if (!resLayoutInfo.isAssigned())
668 return;
669
670 // Only consider vector to vector broadcasts for now.
671 VectorType resultTy = broadcast.getResultVectorType();
672 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
673 // skip layout propagation for non-vector source operand.
674 if (!sourceTy)
675 return;
676
677 auto srcShape = sourceTy.getShape();
678 auto resShape = resultTy.getShape();
679
680 auto resultLayoutAttr =
681 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
682
683 xegpu::DistributeLayoutAttr srcLayoutAttr =
684 xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
685
686 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
687}
688
689void LayoutInfoPropagation::visitShapeCastOp(
690 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
691 ArrayRef<const LayoutInfoLattice *> results) {
692 // The layout of the result must be present.
693 LayoutInfo resLayoutInfo = results[0]->getValue();
694 if (!resLayoutInfo.isAssigned())
695 return;
696 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
697 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
698 auto resultLayoutAttr =
699 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
700
701 xegpu::DistributeLayoutAttr srcLayoutAttr =
702 xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
703
704 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
705}
706
707/// Propagate the layout of the result tensor to the source tensor descriptor
708/// in UpdateNdOffsetOp.
709void LayoutInfoPropagation::visitUpdateNdOffsetOp(
710 xegpu::UpdateNdOffsetOp updateNdOffset,
711 ArrayRef<LayoutInfoLattice *> operands,
712 ArrayRef<const LayoutInfoLattice *> results) {
713 // The layout of the result must be present.
714 LayoutInfo resultLayout = results[0]->getValue();
715 if (!resultLayout.isAssigned())
716 return;
717 // Propagate the layout to the source operand.
718 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
719}
720
/// Set the layouts for DPAS A, B, and C operands. If the op already carries
/// anchor layouts of the propagated kind they are reused; otherwise required
/// layouts are computed from the uArch (and, for subgroup propagation, from
/// the consumer layout and subgroup count) and recorded on the op.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  // If the C/D anchor layout is present, A and B anchors must be too.
  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    // Unknown target chip: nothing can be decided.
    const uArch *uArch = getUArch(getChipStr(dpas).value_or(""));
    if (!uArch)
      return;
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();
    VectorType cdTy = dpas.getResultType();

    xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
        requiredBLayout;

    // Subgroup propagation additionally needs the consumer's layout and the
    // number of subgroups in the workgroup.
    int numSg = 0;
    if (layoutKind == xegpu::LayoutKind::Subgroup) {
      LayoutInfo consumerLayout = results[0]->getValue();
      if (!consumerLayout.isAssigned())
        return;
      consumerLayoutAttr =
          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
      auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
      if (failed(numSgOrErr)) {
        dpas.emitWarning(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      numSg = numSgOrErr.value();
    }
    auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
                                          consumerLayoutAttr, uArch, numSg);
    if (!layouts.has_value()) {
      dpas.emitWarning(
          "Failed to determine required layouts for DPAS operands.");
      return;
    }

    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;

    // Record the computed layouts as the op's anchors for later phases.
    dpas.setLayoutAAttr(requiredALayout);
    dpas.setLayoutBAttr(requiredBLayout);
    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
    dpasALayout = LayoutInfo(requiredALayout);
    dpasBLayout = LayoutInfo(requiredBLayout);
    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
  }
  // Push the chosen layouts to A, B, and (when present) the accumulator C.
  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2)
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}
789
/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
/// Uses the op's anchor layout when present; otherwise derives one from the
/// uArch 2D-block-store parameters according to the propagated layout kind,
/// records it on the op, and pushes it to both operands.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout;
  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    storeLayout = LayoutInfo(anchorLayout);
  } else {
    // Unknown target chip: nothing can be decided.
    const uArch *uArch = getUArch(getChipStr(store).value_or(""));
    if (!uArch)
      return;
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
    VectorType dataTy = store.getValueType();
    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
        store.getValueType().getElementType());
    if (!blockWHC)
      store.emitWarning("No known block params found for the element type.");
    // NOTE(review): value() is called even when blockWHC is empty (only a
    // warning is emitted above) — confirm whether an early return is intended.
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    // Largest divisor of the innermost dim the HW block width allows; -1
    // means no valid multiple (warned, not aborted).
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (dataTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        store.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }

    if (layoutKind == xegpu::LayoutKind::InstData)
      storeLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
    else if (layoutKind == xegpu::LayoutKind::Lane)
      storeLayout =
          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
                                   uArchInstruction->getPackedFormatBitSize());
    else { // xegpu::LayoutKind::Subgroup
      auto sgSize = uArch->getSubgroupSize();
      auto numSgOrErr = getNumSg(store, sgSize);
      if (failed(numSgOrErr)) {
        store.emitWarning(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      // Enumerate candidate subgroup tilings and take the most balanced one
      // (getValidLayouts sorts by balance).
      auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
                                       instData, numSgOrErr.value());
      if (sgLayouts.empty()) {
        store.emitWarning(
            "Unable to determine suitable subgroup layout for store value.");
        return;
      }
      // NOTE(review): this branch indexes shape[0]/shape[1], i.e. it assumes
      // a rank-2 store value — confirm rank-1 cannot reach here.
      SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
      SmallVector<int> sgData = {
          static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
          static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
          dataTy.getContext(),
          DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
          DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
          /*lane_data =*/nullptr, /*order =*/nullptr));
    }
    // Record the decision as the op's anchor layout for later phases.
    store.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
  }
  // Propagate the layout to the value operand.
  // Both operands should have the same layout
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
870
/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp. The anchor layout is reused when present; otherwise the layout
/// already assigned to the loaded value is adopted (and transposed if the
/// load carries a transpose) and recorded on the op.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
  } else {

    LayoutInfo valueLayout = results[0]->getValue();
    // Need the layout of the value to propagate to the tensor descriptor.
    if (!valueLayout.isAssigned())
      return;
    loadLayout = valueLayout;
    // LoadNdOp has the transpose effect. However, at the stage of this analysis
    // this effect is not expected and should be abstracted away. Emit a
    // warning.
    if (auto transpose = load.getTranspose()) {
      load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                       "LayoutInfoPropagation stage.");
      loadLayout = valueLayout.transpose(transpose.value());
    }
    // Record the decision as the op's anchor layout for later phases.
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}
900
901/// For vector::TransposeOp, the layout of the result is transposed and
902/// propagated to the operand.
903void LayoutInfoPropagation::visitTransposeOp(
904 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
905 ArrayRef<const LayoutInfoLattice *> results) {
906 // Need the layout of transpose result to propagate to the operands.
907 LayoutInfo resultLayout = results[0]->getValue();
908 if (!resultLayout.isAssigned())
909 return;
910 auto consumerLayoutAttr =
911 dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
912 auto srcLayoutAttr = xegpu::inferTransposeSourceLayout(
913 consumerLayoutAttr, transpose.getPermutation());
914 // Propagate the new layout to the vector operand.
915 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
916}
917
918/// For vector::BitCastOp, the lane_data of the source layout is changed based
919/// on the bit width of the source and result types.
920void LayoutInfoPropagation::visitVectorBitcastOp(
921 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
922 ArrayRef<const LayoutInfoLattice *> results) {
923 // Need the layout of bitcast result to propagate to the operands.
924 LayoutInfo resLayoutInfo = results[0]->getValue();
925 if (!resLayoutInfo.isAssigned())
926 return;
927
928 auto srcVecType = bitcast.getSourceVectorType();
929 auto resVecType = bitcast.getResultVectorType();
930
931 auto consumerLayoutAttr =
932 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
933 const uArch *uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
934 if (!uArch)
935 return;
936 auto requiredResLayoutAttr = setupBitCastResultLayout(
937 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
938
939 xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
940
941 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
942 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
943
944 // derive the source layout from the dominant layout and reduction dims
945 auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
946 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
947
948 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
949}
950
/// For vector::InsertStridedSliceOp, a required result layout is established
/// from the consumer layout (and recorded as a temporary layout on the op
/// result), then a matching source layout is derived and propagated to the
/// source and destination operands.
951 void LayoutInfoPropagation::visitInsertStridedSliceOp(
952 vector::InsertStridedSliceOp insertStridedSlice,
953 ArrayRef<LayoutInfoLattice *> operands,
954 ArrayRef<const LayoutInfoLattice *> results) {
955 // The layout of the result must be present.
956 LayoutInfo resLayoutInfo = results[0]->getValue();
957 if (!resLayoutInfo.isAssigned())
958 return;
959
960 auto srcVecType = insertStridedSlice.getSourceVectorType();
961 auto resVecType = insertStridedSlice.getDestVectorType();
962
963 auto consumerLayoutAttr =
964 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
// The target micro-architecture is required to set up the result layout.
965 const uArch *uArch =
966 getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
967 if (!uArch)
968 return;
969
970 auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
971 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
972 xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
973 requiredResLayoutAttr);
974
// NOTE(review): the statement that defines `srcLayoutAttr` (original line
// 975) is missing from this extraction; the line below is its argument-list
// continuation — presumably a call like
// xegpu::inferInsertStridedSliceSourceLayout(...). Confirm against the full
// file before editing this region.
976 requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
// Propagate the derived source layout to the inserted slice, and the
// required result layout to the destination vector.
977 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
978 propagateIfChanged(operands[1],
979 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
980 }
981
982/// Propagate the layout of the result to the tensor descriptor, mask and offset
983/// operands in LoadGatherOp.
984void LayoutInfoPropagation::visitLoadGatherOp(
985 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
986 ArrayRef<const LayoutInfoLattice *> results) {
987 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
988 xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
989 const uArch *uArch = getUArch(getChipStr(load).value_or(""));
990 if (!uArch)
991 return;
992 auto subgroupSize = uArch->getSubgroupSize();
993 VectorType resVecTy = load.getValueType();
994 int chunkSize = load.getChunkSize().value_or(1);
995
996 LayoutInfo resLayoutInfo = results[0]->getValue();
997 if (!resLayoutInfo.isAssigned())
998 return;
999 auto consumerLayoutAttr =
1000 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1001
1002 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1003 requiredAnchorLayoutAttr = anchorLayoutAttr;
1004 } else {
1005 if (!resVecTy) {
1006 load.emitWarning("Not propagating, non-vector payload supplied.");
1007 return;
1008 }
1009 requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
1010 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1011 load.setLayoutAttr(requiredAnchorLayoutAttr);
1012 }
1013
1014 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1015 // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
1016 // layout for mask.
1017 if (chunkSize > 1) {
1018 if (layoutKind == xegpu::LayoutKind::InstData)
1019 maskLayoutAttr =
1020 xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
1021 else if (layoutKind == xegpu::LayoutKind::Lane)
1022 maskLayoutAttr =
1023 xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
1024 else
1025 assert(false &&
1026 "chunked StoreScatterOp should not be used at workgroup level");
1027 }
1028
1029 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1030 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1031
1032 // Propagate the new layout to the tensor descriptor operand.
1033 if (isa<xegpu::TensorDescType>(load.getSourceType()))
1034 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1035 // Propagate the new layout to the mask and optional offset operand.
1036 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1037 if (load.getOffsets())
1038 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1039}
1040
1041/// Propagate the layout of the descriptor to the vector offset operand in
1042/// CreateDescOp.
1043void LayoutInfoPropagation::visitCreateDescOp(
1044 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
1045 ArrayRef<const LayoutInfoLattice *> results) {
1046 LayoutInfo descLayout = results[0]->getValue();
1047 // Need the layout of the descriptor to propagate to the operands.
1048 if (!descLayout.isAssigned())
1049 return;
1050 const uArch *uArch = getUArch(getChipStr(createDesc).value_or(""));
1051 if (!uArch)
1052 return;
1053 // For offset operand propagate 1D default layout.
1054 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
1055 uArch->getSubgroupSize());
1056 propagateIfChanged(operands[1], operands[1]->meet(layout));
1057}
1058
1059/// Set the layout for the value, tensor descriptor, offset and mask operands in
1060/// the StoreScatterOp.
1061void LayoutInfoPropagation::visitStoreScatterOp(
1062 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1063 ArrayRef<const LayoutInfoLattice *> results) {
1064
1065 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1066 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1067 const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
1068 if (!uArch)
1069 return;
1070 auto subgroupSize = uArch->getSubgroupSize();
1071 VectorType srcVecTy = storeScatter.getValueType();
1072 int chunkSize = storeScatter.getChunkSize().value_or(1);
1073
1074 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1075 requiredAnchorLayoutAttr = anchorLayoutAttr;
1076 } else {
1077 if (!srcVecTy) {
1078 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
1079 return;
1080 }
1081 requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
1082 layoutKind, srcVecTy, chunkSize, uArch);
1083 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1084 }
1085
1086 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1087 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1088 // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
1089 // layout for mask.
1090 if (chunkSize > 1) {
1091 if (layoutKind == xegpu::LayoutKind::InstData)
1092 maskLayoutAttr =
1093 xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
1094 else if (layoutKind == xegpu::LayoutKind::Lane)
1095 maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
1096 {subgroupSize}, {1});
1097 else
1098 assert(false &&
1099 "chunked StoreScatterOp should not be used at workgroup level");
1100 }
1101
1102 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1103
1104 // Propagate the payload operand layout
1105 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1106 // Propagate the destination (if tdesc) operand layout
1107 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1108 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1109 // Propagate the new layout to the mask and optional offset operand.
1110 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1111 if (storeScatter.getOffsets())
1112 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
1113}
1114
1115void LayoutInfoPropagation::visitLoadMatrixOp(
1116 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1117 ArrayRef<const LayoutInfoLattice *> results) {
1118
1119 LayoutInfo resLayoutInfo = results[0]->getValue();
1120 if (!resLayoutInfo.isAssigned())
1121 return;
1122
1123 auto consumerLayoutAttr =
1124 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1125
1126 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1127
1128 // only need to set anchor layout, no need to porpagate to memdesc and
1129 // offset
1130 if (!hasParamsOfLayoutKind(anchorLayout)) {
1131 VectorType resVecTy =
1132 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1133 const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
1134 if (!uArch)
1135 return;
1136 auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
1137 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1138 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
1139 }
1140}
1141
1142// Store matrix is a flavor of scattered store for 2D shapes.
1143void LayoutInfoPropagation::visitStoreMatrixOp(
1144 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1145 ArrayRef<const LayoutInfoLattice *> results) {
1146 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1147 LayoutInfo layout;
1148 if (hasParamsOfLayoutKind(anchorLayout)) {
1149 layout = LayoutInfo(anchorLayout);
1150 } else {
1151 VectorType srcVecTy =
1152 llvm::cast<VectorType>(storeMatrix.getData().getType());
1153 const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
1154 if (!uArch)
1155 return;
1156 auto requiredAnchorLayoutAttr =
1157 xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
1158 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1159 layout = LayoutInfo(requiredAnchorLayoutAttr);
1160 }
1161
1162 propagateIfChanged(operands[0], operands[0]->meet(layout));
1163}
1164
1165namespace {
1166//===----------------------------------------------------------------------===//
1167// RunLayoutInfoPropagation
1168//===----------------------------------------------------------------------===//
1169
1170 /// Driver class for running the LayoutInfoPropagation analysis.
1171 class RunLayoutInfoPropagation {
1172 public:
1173 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
1174
// Loads the baseline analyses plus LayoutInfoPropagation into the solver and
// runs them immediately over `op`; results are queried afterwards via
// getLayoutInfo().
1175 RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
1176 unsigned indexBitWidth)
1177 : target(op) {
1178 SymbolTableCollection symbolTable;
1179 loadBaselineAnalyses(solver);
1180 solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
// The result is deliberately ignored; unresolved values simply stay
// unassigned in the lattice.
1181 (void)solver.initializeAndRun(op);
1182 }
1183
// Returns the computed layout for `val`, or a default-constructed
// (unassigned) LayoutInfo when the solver holds no state for it.
1184 LayoutInfo getLayoutInfo(Value val);
1185
// Dumps per-function argument and result layouts; for debugging only.
1186 void printAnalysisResult(llvm::raw_ostream &os);
1187
1188 private:
1189 DataFlowSolver solver;
// Root operation the analysis was run on (used when printing results).
1190 const Operation *target;
1191 };
1192} // namespace
1193
1194LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1195 auto *state = solver.lookupState<LayoutInfoLattice>(val);
1196 if (!state)
1197 return {};
1198 return state->getValue();
1199}
1200
1201// Print the analysis result for debugging purposes.
1202void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1203 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1204 os << "function: " << funcOp.getName() << ":\n";
1205 // Function arguments
1206 for (BlockArgument arg : funcOp.getArguments()) {
1207 LayoutInfo layout = getLayoutInfo(arg);
1208 os << "argument: " << arg << "\n";
1209 os << "layout : ";
1210 layout.print(os);
1211 os << "\n";
1212 }
1213 // Function ops
1214 funcOp.walk([&](Operation *op) {
1215 // Skip ops that do not have results
1216 if (op->getResults().empty())
1217 return;
1218 os << "op : ";
1219 // For control-flow ops, print the op name only.
1220 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1221 os << op->getName();
1222 else
1223 op->print(os);
1224 os << "\n";
1225 // Print the layout for each result.
1226 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1227 LayoutInfo layout = getLayoutInfo(r);
1228 os << "layout for result #" << i << ": ";
1229 layout.print(os);
1230 os << "\n";
1231 }
1232 });
1233 };
1234
1235 SmallVector<FunctionOpInterface> funcOps;
1236 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1237 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1238 funcOps.push_back(funcOp);
1239
1240 // Collect all GpuFuncOps in the module.
1241 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1242 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1243 funcOps.push_back(gpuFuncOp);
1244 }
1245 }
1246 // Print the analysis result for each function.
1247 for (FunctionOpInterface funcOp : funcOps)
1248 printFunctionResult(funcOp);
1249}
1250
1251namespace {
1252
1253//===----------------------------------------------------------------------===//
1254// ResolveLayoutConflicts
1255//===----------------------------------------------------------------------===//
1256
1257/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
1258/// function tries to find the defining CreateNdDescOp recursively accross
1259/// control-flow boundaries.
1260static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1261 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1262 auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
1263 if (definingOp)
1264 return definingOp;
1265 // If tdescValue is an argument, try to get the tied init value from the
1266 // parent loop-like op.
1267 if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1268 auto *parentOp = arg.getOwner()->getParentOp();
1269 if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1270 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1271 if (tiedInit)
1272 return getDefiningCreateNdDescOp(tiedInit->get());
1273 }
1274 }
1275 // If not found, return null.
1276 return nullptr;
1277}
1278
// Walks all ops nested under `parentOp` and reconciles producer/consumer
// layout mismatches at vector and tensor-descriptor use points.
1279 struct ResolveLayoutConflicts {
1280 ResolveLayoutConflicts(Operation *parentOp)
1281 : parentOp(parentOp), builder(parentOp->getContext()) {}
// Runs conflict resolution; fails if any conflict cannot be resolved.
1282 LogicalResult run();
1283
1284 private:
1285 Operation *parentOp;
1286 OpBuilder builder;
// Resolves a mismatch on a tensor-descriptor operand of an anchor-layout op
// (by duplicating the defining CreateNdDescOp with the expected layout).
1287 LogicalResult resolveTensorDescConsumer(OpOperand &operand);
// Resolves a mismatch on a vector operand (by inserting a convert_layout).
1288 LogicalResult resolveVectorConsumer(OpOperand &operand);
1289 };
1290
1291} // namespace
1292
1293LogicalResult ResolveLayoutConflicts::run() {
1294 // Scan all operations in the parent op and resolve layout conflicts at
1295 // tensor descriptor and vector use points.
1296 auto r = parentOp->walk([&](Operation *op) -> WalkResult {
1297 for (OpOperand &operand : op->getOpOperands()) {
1298 // Handle conflicts in tensor descriptor operands.
1299 Type operandType = operand.get().getType();
1300 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1301 isa<xegpu::TensorDescType>(operandType)) {
1302 auto res = resolveTensorDescConsumer(operand);
1303 if (failed(res)) {
1304 DBGS() << "Failed to resolve tensor descriptor consumer: " << *op
1305 << "\n";
1306 return WalkResult::interrupt();
1307 }
1308 }
1309 // Handle conflicts in vector operands.
1310 if (isa<VectorType>(operandType)) {
1311 auto res = resolveVectorConsumer(operand);
1312 if (failed(res)) {
1313 DBGS() << "Failed to resolve vector consumer: " << *op << "\n";
1314 return WalkResult::interrupt();
1315 }
1316 }
1317 }
1318 return WalkResult::advance();
1319 });
1320
1321 return r.wasInterrupted() ? failure() : success();
1322}
1323
1324LogicalResult
1325ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1326 Value vectorValue = operand.get();
1327 Operation *consumerOp = operand.getOwner();
1328 // Get the current layout of the vector value.
1329 auto producerLayout = xegpu::getDistributeLayoutAttr(vectorValue);
1330 if (!producerLayout) {
1331 if (auto vectorTy = dyn_cast<VectorType>(vectorValue.getType());
1332 vectorTy && vectorTy.getRank() > 1)
1333 consumerOp->emitWarning("Expected layout for non-1D vectors.");
1334 return success(); // uniform non-tensor-data vector does not require layout
1335 }
1336 // Get the consumer expected layout at this operand.
1337 auto consumerLayout = xegpu::getConsumerLayoutAt(operand);
1338 if (!consumerLayout)
1339 return consumerOp->emitError(
1340 "No consumer layout found for vector operand.");
1341
1342 // If layouts are same, no conflict exists, return success.
1343 if (consumerLayout.isEqualTo(producerLayout))
1344 return success();
1345
1346 // Insert a convert_layout op to resolve the conflict.
1347 builder.setInsertionPointAfterValue(vectorValue);
1348 auto convertOp = xegpu::ConvertLayoutOp::create(
1349 builder, consumerOp->getLoc(), vectorValue.getType(), vectorValue,
1350 producerLayout, consumerLayout);
1351
1352 // Update the operand to use the converted value.
1353 operand.set(convertOp.getResult());
1354 return success();
1355}
1356
1357LogicalResult
1358ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1359 Operation *consumerOp = operand.getOwner();
1360 Value tdescValue = operand.get();
1361 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1362 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
1363 assert(anchorOp && currTDescType &&
1364 "Expected anchor layout op and tensor descriptor consumer.");
1365 // TODO: Scattered tensor desc is not supported for now.
1366 if (currTDescType.isScattered()) {
1367 DBGS() << "Scattered tensor descriptor not supported: " << tdescValue
1368 << "\n";
1369 return failure();
1370 }
1371 Attribute currLayout = currTDescType.getLayout();
1372 Attribute expectedLayout = anchorOp.getAnchorLayout();
1373 // A conflict exists in tensor descriptor operand if tensor descriptor's
1374 // layout is different from the anchor layout expected by the consumer.
1375 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1376 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1377 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1378 if (!conflictingCreateNdOp) {
1379 DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
1380 << tdescValue << "\n";
1381 return failure();
1382 }
1383 // Duplicate the CreateNdDescOp with the expected layout.
1384 builder.setInsertionPointAfter(conflictingCreateNdOp);
1385 auto newTensorDescType = xegpu::TensorDescType::get(
1386 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1387 currTDescType.getElementType(), currTDescType.getEncoding(),
1388 expectedLayout);
1389 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1390 builder, consumerOp->getLoc(), newTensorDescType,
1391 conflictingCreateNdOp->getOperands(),
1392 conflictingCreateNdOp->getAttrs());
1393 // Replace the tensor descriptor operand in the consumer op with the new
1394 // tensor descriptor.
1395 consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
1396 }
1397 return success();
1398}
1399
1400using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1401 /// Update an operation with the layout of its results. If the result type is
1402 /// a vector type, a temporary layout attribute is added to the operation. If
1403 /// the result type is a tensor descriptor type, the type is updated with the
1404 /// layout attribute. The users of the result are also updated with the layout
1405 /// attribute.
1406 static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1407 GetLayoutFnTy getLayoutOfValue) {
1408 // Region ops (like scf.for) are already handled by the
1409 // updateControlFlowOps.
1410 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1411 return success();
1412
1413 // Iterate over all the results.
1414 for (OpResult result : op->getResults()) {
1415 Type resultType = result.getType();
1416 // Layouts are needed only for vector and tensor descriptor types.
1417 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1418 continue;
1419 // If the result has no layout but has users, emit a warning and continue.
1420 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1421 if (!layout && result.getNumUses() > 0) {
1422 op->emitWarning("op has users but no layout assigned for its result")
1423 continue;
1424 }
1425 // If the result is a tensor descriptor type, update the tensor desc type
1426 // with layout.
1427 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1428 auto typeWithLayout = xegpu::TensorDescType::get(
1429 tensorDescTy.getContext(), tensorDescTy.getShape(),
1430 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1431 result.setType(typeWithLayout);
1432 continue;
1433 }
1434 // If the result is a vector type, add a temporary layout attribute to the
1435 // op.
// NOTE(review): the statement that records the temporary layout for vector
// results (original line 1436) is missing from this extraction — presumably
// a call like xegpu::setTemporaryLayout(result, layout); confirm against the
// full file before editing this region.
1437 }
1438 return success();
1439 }
1440
1441 /// Region ops like scf.for need special handling because they have blocks
1442 /// inside. If the blocks have tensor descriptor type as block arguments,
1443 /// their types must be updated. Also region op can have results that may not
1444 /// have any users (e.g. A and B tiles). They are not assigned a layout by
1445 /// layout analysis because they have no users. However inside the region op
1446 /// corresponding block arguments for these results do have layouts.
1447 /// Therefore, in this case we still need to update the result types with the
1448 /// layout attribute. This function updates the internal block
1449 /// arguments and the result types of the region op with the assigned layouts.
1450 /// clang-format off
1451 /// Example: scf.for ... iter_args(...) -> (out types) {
1452 /// ^bb0(block types):
1453 /// ...
1454 /// scf.yield ... : (yield types)
1455 /// }
1456 /// clang-format on
1457 /// In this example, at scf.yield, control-flow can transfer to two successor
1458 /// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1459 /// itself (yield the results). So we update both the block arguments of the
1460 /// successor region (i.e. block types) and the result types of the scf.for op
1461 /// (i.e. out types). Note that yield types are updated by respective
1462 /// producers inside bb0.
1463 static LogicalResult
// NOTE(review): the signature continuation (original line 1464,
// `updateControlFlowOps(mlir::OpBuilder &builder,`) is missing from this
// extraction; confirm against the full file.
1465 mlir::RegionBranchTerminatorOpInterface terminator,
1466 GetLayoutFnTy getLayoutOfValue) {
1467 // Only process if the terminator is inside a region branch op.
1468 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1469 if (!branchOp)
1470 return success();
1471
// NOTE(review): the declaration of `mapping` (original line 1472) is missing
// from this extraction — presumably a container populated by
// getSuccessorOperandInputMapping below; confirm against the full file.
1473 branchOp.getSuccessorOperandInputMapping(mapping,
1474 RegionBranchPoint(terminator));
1475 for (const auto &[successorOperand, successorInputs] : mapping) {
1476 for (Value successorInput : successorInputs) {
1477 Type inputType = successorInput.getType();
1478 // We only need to operate on tensor descriptor or vector types.
1479 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1480 continue;
1481 xegpu::DistributeLayoutAttr successorInputLayout =
1482 getLayoutOfValue(successorInput);
1483 xegpu::DistributeLayoutAttr successorOperandLayout =
1484 getLayoutOfValue(successorOperand->get());
1485
1486 // If either of the layouts is not assigned, we cannot proceed.
1487 if (!successorOperandLayout) {
1488 LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1489 "branch terminator: "
1490 << successorOperand->get() << "\n");
1491 return failure();
1492 }
1493 // We expect the layouts to match.
1494 if (successorInputLayout &&
1495 successorInputLayout != successorOperandLayout) {
1496 LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1497 "operand forwarded as the argument: "
1498 << successorInputLayout << " vs "
1499 << successorOperandLayout << "\n");
1500 return failure();
1501 }
1502 // Get tensor descriptor type with the layout.
1503 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1504 auto newTdescTy = xegpu::TensorDescType::get(
1505 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1506 tdescTy.getEncoding(), successorOperandLayout);
1507 successorInput.setType(newTdescTy);
1508 continue;
1509 }
1510 // If the type is a vector type and this region argument is an OpResult,
1511 // set the layout attribute on the OpResult.
1512 if (auto result = dyn_cast<OpResult>(successorInput))
1513 xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1514 }
1515 }
1516 return success();
1517 }
1518
1519/// Update the function arguments and results with the layouts.
1520static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1521 mlir::FunctionOpInterface funcOp,
1522 GetLayoutFnTy getLayoutOfValue) {
1523 // Only process functions whose type is a standard MLIR FunctionType.
1524 // Functions using a different type representation (e.g. llvm.func with
1525 // LLVMFunctionType) are not targets for XeGPU layout propagation, and
1526 // calling setType(FunctionType{}) on them would corrupt their type.
1527 if (!isa<FunctionType>(funcOp.getFunctionType()))
1528 return success();
1529 SmallVector<Type> newArgTypes;
1530 // Update the function arguments.
1531 for (BlockArgument arg : funcOp.getArguments()) {
1532 Type argType = arg.getType();
1533 newArgTypes.push_back(argType);
1534 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1535 continue;
1536 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1537 if (!layout) {
1538 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1539 << " but got none.\n");
1540 return failure();
1541 }
1542 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1543 auto newTdescTy = xegpu::TensorDescType::get(
1544 tensorDescTy.getContext(), tensorDescTy.getShape(),
1545 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1546 arg.setType(newTdescTy);
1547 newArgTypes.back() = newTdescTy;
1548 }
1549 }
1550 // Update the function type with the new argument types.
1551 // NOTE: We assume that function results are not expected to have layouts.
1552 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1553 funcOp.getResultTypes()));
1554 return success();
1555}
1556
1557namespace {
// Pass wrapper generated from Passes.td options; the actual work happens in
// runOnOperation().
1558 struct XeGPUPropagateLayoutPass final
1559 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1560 XeGPUPropagateLayoutPass() = default;
1561 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
// Constructs the pass from generated pass options (layout kind, index bit
// width, print-only mode).
1562 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1563 : XeGPUPropagateLayoutBase(std::move(options)) {}
1564 void runOnOperation() override;
1565 };
1566
1567} // namespace
1568
// Runs the LayoutInfoPropagation analysis over `target` and writes the
// computed layouts back into the IR (tensor descriptor types, temporary
// layout attributes, function signatures, region-branch types).
// NOTE(review): the first line of the signature (original line 1569,
// presumably `LogicalResult xegpu::propagateLayouts(OpBuilder &builder,
// Operation *target,`) is missing from this extraction.
1570 LayoutKind layoutKind,
1571 unsigned indexBitWidth, bool printOnly) {
1572 RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
1573 // Print the analysis result and exit. (for debugging purposes)
1574 if (printOnly) {
1575 auto &os = llvm::outs();
1576 analysis.printAnalysisResult(os);
1577 return success();
1578 }
1579 // Helper to convert LayoutInfo to xegpu::LayoutAttr.
1580 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1581 LayoutInfo layout = analysis.getLayoutInfo(val);
1582 if (!layout.isAssigned())
1583 return {};
1584 if (auto opResult = dyn_cast<OpResult>(val)) {
1585
// An anchor layout on the defining op takes precedence over the analysis
// result for that op's results.
1586 Operation *defOp = opResult.getDefiningOp();
1587 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1588 auto anchorLayout = anchorOp.getAnchorLayout();
1589 if (anchorLayout != nullptr)
1590 return anchorLayout;
1591 }
// Next preference: a temporary layout previously recorded on the result.
1592 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1593 xegpu::getTemporaryLayout(opResult);
1594 if (requiredResLayoutAttr != nullptr)
1595 return requiredResLayoutAttr;
1596 }
// Fall back to the layout computed by the analysis.
1597 xegpu::DistributeLayoutAttr layoutAttr =
1598 cast<xegpu::DistributeLayoutAttr>(layout.get());
1599 if (layout.isSliceLayout())
1600 return cast<xegpu::SliceAttr>(layoutAttr);
1601
1602 return cast<xegpu::LayoutAttr>(layoutAttr);
1603 };
1604
// Walk blocks bottom-up so users are updated before their producers.
1605 Operation *op = target;
1606 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1607 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1608 LogicalResult r = success();
// NOTE(review): the TypeSwitch head (original line 1609, presumably
// `TypeSwitch<Operation *>(&op)`) is missing from this extraction.
1610 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1611 r = updateControlFlowOps(builder, branchTermOp,
1612 getXeGPULayoutForValue);
1613 })
1614 .Case([&](mlir::FunctionOpInterface funcOp) {
1615 r = updateFunctionOpInterface(builder, funcOp,
1616 getXeGPULayoutForValue);
1617 })
1618 .Default([&](Operation *op) {
1619 r = updateOp(builder, op, getXeGPULayoutForValue);
1620 });
1621 if (failed(r)) {
1622 op.emitError("Failed to update operation with the layout.");
1623 return WalkResult::interrupt();
1624 }
1625 }
1626 return WalkResult::advance();
1627 });
1628 if (walkResult.wasInterrupted())
1629 return failure();
1630
1631 return success();
1632 }
1633
// Entry point for layout-conflict resolution: walks `target` and inserts the
// conversions needed to reconcile producer and consumer layouts.
// NOTE(review): the signature line (original line 1634, presumably
// `LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {`) is
// missing from this extraction.
1635 ResolveLayoutConflicts resolver(target);
1636 return resolver.run();
1637 }
1638
1639void XeGPUPropagateLayoutPass::runOnOperation() {
1640 xegpu::LayoutKind layoutKind;
1641 if (this->layoutKind == "lane") {
1642 layoutKind = xegpu::LayoutKind::Lane;
1643 } else if (this->layoutKind == "inst") {
1644 layoutKind = xegpu::LayoutKind::InstData;
1645 } else if (this->layoutKind == "subgroup") {
1646 layoutKind = xegpu::LayoutKind::Subgroup;
1647 } else {
1648 getOperation()->emitError("Unsupported layout kind option: " +
1649 this->layoutKind);
1650 signalPassFailure();
1651 return;
1652 }
1653 OpBuilder builder(&getContext());
1654 if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
1655 this->indexBitWidth, this->printOnly))) {
1656 signalPassFailure();
1657 return;
1658 }
1659 // Resolve layout conflicts if any.
1660 if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
1661 signalPassFailure();
1662 return;
1663 }
1664}
return success()
#define DBGS()
Definition Hoisting.cpp:32
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
lhs
b getContext())
auto load
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
The general data-flow analysis solver.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:412
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
result_range getResults()
Definition Operation.h:444
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
DistributeLayoutAttr inferTransposeSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > permutation)
Infers the source layout attribute for a transpose operation given the result layout attribute and pe...
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, const uArch::uArch *uArch)
Sets up layout for reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch, int numSg)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, unsigned indexBitWidth, bool printOnly=false)
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:163