MLIR 23.0.0git
XeGPUPropagateLayout.cpp
Go to the documentation of this file.
1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Value.h"
27#include "mlir/IR/Visitors.h"
31#include "mlir/Support/LLVM.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/LogicalResult.h"
40#include "llvm/Support/raw_ostream.h"
41
42namespace mlir {
43namespace xegpu {
44#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
45#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
46} // namespace xegpu
47} // namespace mlir
48
49#define DEBUG_TYPE "xegpu-propagate-layout"
50#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
51
52using namespace mlir;
53using namespace mlir::dataflow;
54
55namespace {
56
57//===----------------------------------------------------------------------===//
58// LayoutInfo
59//===----------------------------------------------------------------------===//
60
61/// Helper class for tracking the analysis state of an mlir value. For layout
62/// propagation, the analysis state is simply the distribution layout of
63/// each value. The distribution layout information is encapsulated using
64/// xegpu::DistributeLayoutAttr class which can hold information about any type
/// of distribution layout that the XeGPU dialect supports. The purpose of this
/// analysis is to propagate some unique distribution layout for each value in
/// the program, starting from a set of anchor operations (like DPAS, StoreNd,
/// etc.). Note that the analysis reaches a fixed point once every value has
/// been assigned some layout; it does not try to modify any already-assigned
/// layouts.
70///
/// Given this, LayoutInfo satisfies the following properties:
72/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
73/// assigned`.
74/// 2) Two LayoutInfo values are equal if they are both assigned or
75/// both not assigned. The concrete value of assigned state does not matter.
76/// 3) The meet operator works as follows:
77/// - If current state is assigned, return the current state. (already
78/// a unique layout is assigned. don't change it)
79/// - Otherwise, return the other state.
80
81struct LayoutInfo {
82private:
83 xegpu::DistributeLayoutAttr storage = nullptr;
84
85public:
86 LayoutInfo() = default;
87 LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
88
89 // Two lattice values are equal if they have `some` layout. The actual
90 // content of the layout does not matter.
91 bool operator==(const LayoutInfo &other) const {
92 return this->isAssigned() == other.isAssigned();
93 }
94
95 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
96
97 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
98
99 void print(raw_ostream &os) const;
100
101 bool isAssigned() const { return storage != nullptr; }
102
103 LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
104
105 SmallVector<int> getLaneLayout() const;
106
107 SmallVector<int> getLaneData() const;
108
109 SmallVector<int> getInstData() const;
110
111 SmallVector<int> getSgLayout() const;
112
113 SmallVector<int> getSgData() const;
114
115 SmallVector<int> getOrder() const;
116
117 bool isSliceLayout() const {
118 if (!isAssigned())
119 return false;
120 return isa<xegpu::SliceAttr>(storage);
121 }
122
123 int64_t getRank() const {
124 if (!isAssigned())
125 return -1;
126 return storage.getRank();
127 }
128
129 Attribute get() { return storage; }
130 void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
131};
132
133SmallVector<int> LayoutInfo::getLaneLayout() const {
134 if (!isAssigned())
135 return {};
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](int64_t val) { return static_cast<int>(val); });
138}
139
140SmallVector<int> LayoutInfo::getLaneData() const {
141 if (!isAssigned())
142 return {};
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](int64_t val) { return static_cast<int>(val); });
145}
146
147SmallVector<int> LayoutInfo::getInstData() const {
148 if (!isAssigned())
149 return {};
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](int64_t val) { return static_cast<int>(val); });
152}
153
154SmallVector<int> LayoutInfo::getSgLayout() const {
155 if (!isAssigned())
156 return {};
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](int64_t val) { return static_cast<int>(val); });
159}
160
161SmallVector<int> LayoutInfo::getSgData() const {
162 if (!isAssigned())
163 return {};
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](int64_t val) { return static_cast<int>(val); });
166}
167
168SmallVector<int> LayoutInfo::getOrder() const {
169 if (!isAssigned() || !storage.getOrder())
170 return {};
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](int64_t val) { return static_cast<int>(val); });
173}
174
175void LayoutInfo::print(raw_ostream &os) const {
176 if (isAssigned()) {
177 os << storage;
178 } else {
179 os << "Not assigned.";
180 }
181}
182
183LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
184 if (!lhs.isAssigned())
185 return rhs;
186 return lhs;
187}
188
/// Since this is a backward analysis, join method is not used.
/// The solver is expected to drive propagation through meet() only.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
193
/// Construct a new layout with the transposed inst_data or lane_layout,
/// lane_data.
/// Returns an unassigned LayoutInfo when this value has no layout or when
/// `permutation` is invalid.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  // Check if the permutation is valid.
  // A valid permutation has no duplicate entries and every index lies in
  // [0, permutation.size()).
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    // Hard failure in debug builds; release builds degrade to "unassigned".
    assert(false && "Invalid permutation for transpose.");
    return {};
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  SmallVector<int32_t> sgLayout;

  // Gather the permuted components for whichever layout kinds are populated.
  // NOTE(review): the getXXX() accessors rebuild their vectors on every call;
  // hoisting them out of the loop would avoid repeated recomputation.
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
    if (getSgData().size()) {
      sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
      sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
    }
    if (getOrder().size()) {
      order.push_back(static_cast<int32_t>(getOrder()[idx]));
    }
  }
  auto orderAttr = order.size()
                       ? DenseI32ArrayAttr::get(storage.getContext(), order)
                       : nullptr;
  // Build the result; the assignments below are not mutually exclusive, so
  // when several kinds are populated the later one wins (sg data last).
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  if (getSgData().size())
    layoutAttr = xegpu::LayoutAttr::get(
        storage.getContext(),
        DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
        DenseI32ArrayAttr::get(storage.getContext(), sgData),
        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
        /*lane_data =*/nullptr, orderAttr);
  return LayoutInfo(layoutAttr);
}
251
252//===----------------------------------------------------------------------===//
253// LayoutInfoLattice
254//===----------------------------------------------------------------------===//
255
/// Lattice holding the LayoutInfo for each value.
/// Thin wrapper over the generic dataflow Lattice; all semantics live in
/// LayoutInfo itself.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
261
262/// Helper Functions to get default layouts. A `default layout` is a layout that
263/// is assigned to a value when the layout is not fixed by some anchor operation
264/// (like DPAS).
265
266/// Helper Function to get the default layout for uniform values like constants.
267/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
268/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
269static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
270 unsigned rank,
271 const xegpu::uArch::uArch *uArch) {
272 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
273 if (rank == 1) {
274 return LayoutInfo(
275 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
276 }
277 return LayoutInfo(
278 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
279}
280
281static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
282 unsigned rank, int subgroupSize) {
283 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
284 if (rank == 1) {
285 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
286 }
287 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
288}
289
/// Helper to get the default layout for 2D block operations.
/// `packingSize` is the packed-format bit size used to derive the per-lane
/// packing factor for element types narrower than that size.
template <typename Ty>
static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
                                 unsigned packingSize) {
  // Expecting a 1D or 2D vector.
  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(ty.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return default layout for 1D vector.
  if (ty.getRank() == 1)
    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  // One row per lane in the subgroup, `packingFactor` contiguous elements
  // per lane along the innermost dimension.
  return LayoutInfo(xegpu::LayoutAttr::get(
      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
310
311//===----------------------------------------------------------------------===//
312// LayoutInfoPropagation
313//===----------------------------------------------------------------------===//
314
315/// Backward data flow analysis to propagate the lane_layout and lane_data of
316/// each value in the program. Currently, the layouts for operands DPAS,
317/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
318/// this analysis is to propagate those known layouts to all their producers and
319/// (other) consumers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  // Which layout component (lane / inst_data / subgroup) this run of the
  // analysis propagates.
  xegpu::LayoutKind layoutKind;
  // Op-specific visitors. Anchor operations (DPAS, stores, prefetch) fix
  // layouts; the remaining visitors derive operand layouts from the layouts
  // already assigned to their results.
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,

  void visitStoreNdOp(xegpu::StoreNdOp store,

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,

  void visitLoadNdOp(xegpu::LoadNdOp load,

  void visitLoadGatherOp(xegpu::LoadGatherOp load,

  void visitTransposeOp(vector::TransposeOp transpose,

  void visitVectorBitcastOp(vector::BitCastOp bitcast,

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,

  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
  void
  visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,

  void visitLoadMatrixOp(xegpu::LoadMatrixOp load,

  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,

  void visitLoadGatherOp(xegpu::LoadMatrixOp load,

  void visitStoreScatterOp(xegpu::StoreMatrixOp store,

  // True when `anchorLayout` already carries the parameters relevant to
  // `layoutKind` (see the out-of-line definition).
  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        xegpu::LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  // Control-flow and call hooks are intentionally no-ops: this analysis
  // only propagates layouts through value def-use chains.
  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void
  visitNonControlFlowArguments(RegionSuccessor &successor,
                               ArrayRef<BlockArgument> arguments) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  // Exit state is "no layout assigned" (meet with an empty LayoutInfo is a
  // no-op on already-assigned lattices).
  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
425} // namespace
426
/// Dispatch to the op-specific visitor. For every op without a dedicated
/// visitor, forward each already-assigned result layout to all operands of
/// vector or tensor-descriptor type.
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
      .Case(
          [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case([&](xegpu::StoreNdOp storeNdOp) {
        visitStoreNdOp(storeNdOp, operands, results);
      })
      .Case([&](xegpu::StoreScatterOp storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case([&](xegpu::LoadNdOp loadNdOp) {
        visitLoadNdOp(loadNdOp, operands, results);
      })
      .Case([&](xegpu::LoadGatherOp loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case([&](xegpu::CreateDescOp createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case([&](vector::TransposeOp transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case([&](vector::BitCastOp bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case([&](vector::MultiDimReductionOp reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case([&](vector::BroadcastOp broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case([&](vector::ShapeCastOp shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
        visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
      })
      .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
        visitLoadMatrixOp(loadMatrixOp, operands, results);
      })
      .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
        visitStoreMatrixOp(storeMatrixOp, operands, results);
      })
      // All other ops.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // If the operand type is not a vector or tensor descriptor, skip
            // it.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            // Propagate the result layout to the operand.
            meet(operandInfo, *resultInfo);
          }
        }
      });

  return success();
}
498
499bool LayoutInfoPropagation::hasParamsOfLayoutKind(
500 xegpu::DistributeLayoutAttr anchorLayout) {
501 if (anchorLayout == nullptr) {
502 return false;
503 }
504 if (layoutKind == xegpu::LayoutKind::InstData) {
505 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
506 }
507 if (layoutKind == xegpu::LayoutKind::Lane) {
508 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
509 anchorLayout.getEffectiveLaneDataAsInt().empty());
510 }
511 if (layoutKind == xegpu::LayoutKind::Subgroup) {
512 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
513 anchorLayout.getEffectiveSgDataAsInt().empty());
514 }
515 return false;
516}
517
// This function returns all layouts for the given sgCount, whose sgData:
// 1. Evenly divides the wgShape.
// 2. Is a multiple of instData.
// Example:
// wgShape = [128, 64], instData = [8, 16], sgCount = 32
// Returns layouts:
// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
// Note: indexes wgShape[0]/[1] and instData[0]/[1], so both are assumed to
// have (at least) two elements.
                                          ArrayRef<int> instData,
                                          int64_t sgCount) {
  // Enumerate every factorization sgCount == sgLayout0 * sgLayout1 and keep
  // the ones whose per-subgroup tile divides the shape and is a multiple of
  // the instruction data.
  for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
    if (sgCount % sgLayout0)
      continue;
    int sgLayout1 = sgCount / sgLayout0;
    int sgData0 = wgShape[0] / sgLayout0;
    int sgData1 = wgShape[1] / sgLayout1;
    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
        (sgData0 % instData[0] || sgData1 % instData[1]))
      continue;
    candidates.emplace_back(sgLayout0, sgLayout1);
  }
  // Sort primarily by how balanced they are
  // (i.e., minimize the absolute difference between the two dimensions), and
  // secondarily by the first dimension in ascending order.
  llvm::sort(candidates, [](const std::pair<int, int> &lhs,
                            const std::pair<int, int> &rhs) {
    int diffLhs = std::abs(lhs.first - lhs.second);
    int diffRhs = std::abs(rhs.first - rhs.second);
    if (diffLhs != diffRhs)
      return diffLhs < diffRhs;
    return lhs.first < rhs.first;
  });
  return candidates;
}
553
554FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
555 // Oblivious to workitem layout, the total count matters.
556 auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
557 if (!gpuFunc)
558 return failure();
559 auto knownBlockSize = gpuFunc.getKnownBlockSize();
560 if (!knownBlockSize.has_value())
561 return failure();
562 const int flatBlockSize = llvm::product_of(knownBlockSize.value());
563 return flatBlockSize / sgSize;
564}
565
566void LayoutInfoPropagation::visitPrefetchNdOp(
567 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
568 ArrayRef<const LayoutInfoLattice *> results) {
569
570 LayoutInfo prefetchLayout;
571 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
572 if (hasParamsOfLayoutKind(anchorLayout)) {
573 prefetchLayout = LayoutInfo(anchorLayout);
574 } else {
575 // Here we assign the default layout to the tensor descriptor operand of
576 // prefetch.
577 auto tdescTy = prefetch.getTensorDescType();
578
579 const uArch *uArch = getUArch(getChipStr(prefetch).value_or(""));
580 if (!uArch)
581 return;
582 const auto *uArchInstruction =
583 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
584 uArch->getInstruction(
585 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
586
587 auto blockWHC =
588 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
589 if (!blockWHC)
590 prefetch.emitWarning("No known block params found for the element type.");
591 auto [bWidth, bHeight, bCount] = blockWHC.value();
592 SmallVector<int> instData;
593 int instWidth = xegpu::getLargestDivisor(
594 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
595 if (instWidth == -1)
596 prefetch.emitWarning(
597 "No suitable instruction multiple found for the given shape.");
598 if (tdescTy.getRank() == 1)
599 instData = {instWidth};
600 else {
601 int instHeight = xegpu::getLargestDivisor(
602 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
603 if (instHeight == -1)
604 prefetch.emitWarning(
605 "No suitable instruction multiple found for the given shape.");
606 instData = {instHeight, instWidth};
607 }
608
609 if (layoutKind == xegpu::LayoutKind::InstData)
610 prefetchLayout =
611 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
612 else
613 prefetchLayout = getSIMTLayoutInfoBlockIO(
614 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
615
616 prefetch.setLayoutAttr(
617 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
618 }
619 // Propagate the layout to the source tensor descriptor.
620 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
621}
622
623void LayoutInfoPropagation::visitVectorMultiReductionOp(
624 vector::MultiDimReductionOp reduction,
625 ArrayRef<LayoutInfoLattice *> operands,
626 ArrayRef<const LayoutInfoLattice *> results) {
627 // The layout of the result must be present.
628 LayoutInfo resLayoutInfo = results[0]->getValue();
629 if (!resLayoutInfo.isAssigned())
630 return;
631
632 VectorType sourceTy = reduction.getSourceVectorType();
633 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
634
635 const uArch *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
636 if (!uArch)
637 return;
638 auto consumerLayoutAttr =
639 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
640
641 // The result layout represents the layout requirements of the operation.
642 // it is recorded to anchor layout or temporary layout.
643 // it must be honored for current op and may conflict with the layout
644 // propagated from consumer op, the conflict is resolved in later phase by
645 // converting the required result layout to the consumer layout
646 auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
647 layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
648
649 xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
650
651 // derive the source layout from the dominant layout and reduction dims
652 auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
653 requiredResLayoutAttr, reductionDims);
654
655 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
656 // Accumulator should have the same layout as the result.
657 propagateIfChanged(operands[1],
658 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
659}
660
661void LayoutInfoPropagation::visitVectorBroadCastOp(
662 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
663 ArrayRef<const LayoutInfoLattice *> results) {
664 // The layout of the result must be present.
665 LayoutInfo resLayoutInfo = results[0]->getValue();
666 if (!resLayoutInfo.isAssigned())
667 return;
668
669 // Only consider vector to vector broadcasts for now.
670 VectorType resultTy = broadcast.getResultVectorType();
671 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
672 // skip layout propagation for non-vector source operand.
673 if (!sourceTy)
674 return;
675
676 auto srcShape = sourceTy.getShape();
677 auto resShape = resultTy.getShape();
678
679 size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
680 for (size_t i = 0; i < srcShape.size(); i++)
681 if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
682 broadcast.emitWarning("broadcast must either from low-rank or same-rank "
683 "with unit-dim, mixed scenario is not supported!");
684
685 auto resultLayoutAttr =
686 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
687
688 xegpu::DistributeLayoutAttr srcLayoutAttr =
689 xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
690
691 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
692}
693
694void LayoutInfoPropagation::visitShapeCastOp(
695 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
696 ArrayRef<const LayoutInfoLattice *> results) {
697 // The layout of the result must be present.
698 LayoutInfo resLayoutInfo = results[0]->getValue();
699 if (!resLayoutInfo.isAssigned())
700 return;
701 ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
702 ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
703 auto resultLayoutAttr =
704 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
705
706 xegpu::DistributeLayoutAttr srcLayoutAttr =
707 xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
708
709 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
710}
711
712/// Propagate the layout of the result tensor to the source tensor descriptor
713/// in UpdateNdOffsetOp.
714void LayoutInfoPropagation::visitUpdateNdOffsetOp(
715 xegpu::UpdateNdOffsetOp updateNdOffset,
716 ArrayRef<LayoutInfoLattice *> operands,
717 ArrayRef<const LayoutInfoLattice *> results) {
718 // The layout of the result must be present.
719 LayoutInfo resultLayout = results[0]->getValue();
720 if (!resultLayout.isAssigned())
721 return;
722 // Propagate the layout to the source operand.
723 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
724}
725
/// Set the layouts for DPAS A, B, and C operands.
/// If a CD anchor layout of the right kind is already attached, the A and B
/// anchors are required to be present too (asserted) and all three are
/// reused. Otherwise the layouts are derived from the uArch; for subgroup
/// propagation this additionally needs the consumer layout of the result and
/// the number of subgroups in the workgroup.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    // Anchors already fixed on the op; reuse them verbatim.
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    const uArch *uArch = getUArch(getChipStr(dpas).value_or(""));
    if (!uArch)
      return;
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();
    VectorType cdTy = dpas.getResultType();

    xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
        requiredBLayout;

    // numSg is only consulted for subgroup-kind propagation (0 otherwise).
    int numSg = 0;
    if (layoutKind == xegpu::LayoutKind::Subgroup) {
      // Subgroup layouts additionally depend on the layout demanded by the
      // result's consumer; wait for it to be assigned.
      LayoutInfo consumerLayout = results[0]->getValue();
      if (!consumerLayout.isAssigned())
        return;
      consumerLayoutAttr =
          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
      auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
      if (failed(numSgOrErr)) {
        dpas.emitWarning(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      numSg = numSgOrErr.value();
    }
    auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
                                          consumerLayoutAttr, uArch, numSg);
    if (!layouts.has_value()) {
      dpas.emitWarning(
          "Failed to determine required layouts for DPAS operands.");
      return;
    }

    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;

    // Record the derived layouts as anchors on the op for later phases.
    dpas.setLayoutAAttr(requiredALayout);
    dpas.setLayoutBAttr(requiredBLayout);
    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
    dpasALayout = LayoutInfo(requiredALayout);
    dpasBLayout = LayoutInfo(requiredBLayout);
    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
  }
  // Propagate to A, B, and (when present) the accumulator operand.
  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2)
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}
794
795/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
796void LayoutInfoPropagation::visitStoreNdOp(
797 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
798 ArrayRef<const LayoutInfoLattice *> results) {
799 LayoutInfo storeLayout;
800 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
801 if (hasParamsOfLayoutKind(anchorLayout)) {
802 storeLayout = LayoutInfo(anchorLayout);
803 } else {
804 const uArch *uArch = getUArch(getChipStr(store).value_or(""));
805 if (!uArch)
806 return;
807 const auto *uArchInstruction =
808 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
809 uArch->getInstruction(
810 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
811 VectorType dataTy = store.getValueType();
812 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
813 store.getValueType().getElementType());
814 if (!blockWHC)
815 store.emitWarning("No known block params found for the element type.");
816 auto [bWidth, bHeight, bCount] = blockWHC.value();
817 SmallVector<int> instData;
818 int instWidth = xegpu::getLargestDivisor(
819 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
820 if (instWidth == -1)
821 store.emitWarning(
822 "No suitable instruction multiple found for the given shape.");
823 if (dataTy.getRank() == 1)
824 instData = {instWidth};
825 else {
826 int instHeight = xegpu::getLargestDivisor(
827 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
828 if (instHeight == -1)
829 store.emitWarning(
830 "No suitable instruction multiple found for the given shape.");
831 instData = {instHeight, instWidth};
832 }
833
834 if (layoutKind == xegpu::LayoutKind::InstData)
835 storeLayout =
836 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
837 else if (layoutKind == xegpu::LayoutKind::Lane)
838 storeLayout =
839 getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
840 uArchInstruction->getPackedFormatBitSize());
841 else { // xegpu::LayoutKind::Subgroup
842 auto sgSize = uArch->getSubgroupSize();
843 auto numSgOrErr = getNumSg(store, sgSize);
844 if (failed(numSgOrErr)) {
845 store.emitWarning(
846 "Unable to determine the number of subgroups for the operation.");
847 return;
848 }
849 auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
850 instData, numSgOrErr.value());
851 if (sgLayouts.empty()) {
852 store.emitWarning(
853 "Unable to determine suitable subgroup layout for store value.");
854 return;
855 }
856 SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
857 SmallVector<int> sgData = {
858 static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
859 static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
860 storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
861 dataTy.getContext(),
862 DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
863 DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
864 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
865 /*lane_data =*/nullptr, /*order =*/nullptr));
866 }
867 store.setLayoutAttr(
868 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
869 }
870 // Propagate the layout to the value operand.
871 // Both operands should have the same layout
872 for (LayoutInfoLattice *operand : operands)
873 propagateIfChanged(operand, operand->meet(storeLayout));
874}
875
876/// Propagate the layout of the value to the tensor descriptor operand in
877/// LoadNdOp.
878void LayoutInfoPropagation::visitLoadNdOp(
879 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
880 ArrayRef<const LayoutInfoLattice *> results) {
881 LayoutInfo loadLayout;
882 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
883 if (hasParamsOfLayoutKind(anchorLayout)) {
884 loadLayout = LayoutInfo(anchorLayout);
885 } else {
886
887 LayoutInfo valueLayout = results[0]->getValue();
888 // Need the layout of the value to propagate to the tensor descriptor.
889 if (!valueLayout.isAssigned())
890 return;
891 loadLayout = valueLayout;
892 // LoadNdOp has the transpose effect. However, at the stage of this analysis
893 // this effect is not expected and should be abstracted away. Emit a
894 // warning.
895 if (auto transpose = load.getTranspose()) {
896 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
897 "LayoutInfoPropagation stage.");
898 loadLayout = valueLayout.transpose(transpose.value());
899 }
900 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
901 }
902 // Propagate the new layout to the tensor descriptor operand.
903 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
904}
905
906/// For vector::TransposeOp, the layout of the result is transposed and
907/// propagated to the operand.
908void LayoutInfoPropagation::visitTransposeOp(
909 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
910 ArrayRef<const LayoutInfoLattice *> results) {
911 // Need the layout of transpose result to propagate to the operands.
912 LayoutInfo resultLayout = results[0]->getValue();
913 if (!resultLayout.isAssigned())
914 return;
915 LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
916 // Propagate the new layout to the vector operand.
917 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
918}
919
920/// For vector::BitCastOp, the lane_data of the source layout is changed based
921/// on the bit width of the source and result types.
922void LayoutInfoPropagation::visitVectorBitcastOp(
923 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
924 ArrayRef<const LayoutInfoLattice *> results) {
925 // Need the layout of bitcast result to propagate to the operands.
926 LayoutInfo resLayoutInfo = results[0]->getValue();
927 if (!resLayoutInfo.isAssigned())
928 return;
929
930 auto srcVecType = bitcast.getSourceVectorType();
931 auto resVecType = bitcast.getResultVectorType();
932
933 auto consumerLayoutAttr =
934 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
935 const uArch *uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
936 if (!uArch)
937 return;
938 auto requiredResLayoutAttr = setupBitCastResultLayout(
939 layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
940
941 xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
942
943 int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
944 int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
945
946 // derive the source layout from the dominant layout and reduction dims
947 auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
948 requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
949
950 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
951}
952
/// Set up the result layout for vector::InsertStridedSliceOp and propagate
/// layouts to the source (operand 0) and destination (operand 1) vectors.
953 void LayoutInfoPropagation::visitInsertStridedSliceOp(
954     vector::InsertStridedSliceOp insertStridedSlice,
955     ArrayRef<LayoutInfoLattice *> operands,
956     ArrayRef<const LayoutInfoLattice *> results) {
957   // The layout of the result must be present.
958   LayoutInfo resLayoutInfo = results[0]->getValue();
959   if (!resLayoutInfo.isAssigned())
960     return;
961 
962   auto srcVecType = insertStridedSlice.getSourceVectorType();
963   auto resVecType = insertStridedSlice.getDestVectorType();
964 
965   auto consumerLayoutAttr =
966       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
      // Bail out when the target uArch cannot be determined from the chip
      // string.
967   const uArch *uArch =
968       getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
969   if (!uArch)
970     return;
971 
      // Compute the layout the result must carry and record it as a temporary
      // layout on the result value.
972   auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
973       layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
974 
975   xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
976                             requiredResLayoutAttr);
977 
      // NOTE(review): original line 978 is missing from this listing — it
      // presumably defines `srcLayoutAttr` by inferring the source layout from
      // the result layout and the two shapes below; confirm against upstream.
979       requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
980 
      // Source vector takes the inferred layout; destination takes the
      // required result layout.
981   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
982   propagateIfChanged(operands[1],
983                      operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
984 }
985
986/// Propagate the layout of the result to the tensor descriptor, mask and offset
987/// operands in LoadGatherOp.
988void LayoutInfoPropagation::visitLoadGatherOp(
989 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
990 ArrayRef<const LayoutInfoLattice *> results) {
991 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
992 xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
993 const uArch *uArch = getUArch(getChipStr(load).value_or(""));
994 if (!uArch)
995 return;
996 auto subgroupSize = uArch->getSubgroupSize();
997 VectorType resVecTy = load.getValueType();
998 int chunkSize = load.getChunkSize().value_or(1);
999
1000 LayoutInfo resLayoutInfo = results[0]->getValue();
1001 if (!resLayoutInfo.isAssigned())
1002 return;
1003 auto consumerLayoutAttr =
1004 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1005
1006 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1007 requiredAnchorLayoutAttr = anchorLayoutAttr;
1008 } else {
1009 if (!resVecTy) {
1010 load.emitWarning("Not propagating, non-vector payload supplied.");
1011 return;
1012 }
1013 requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
1014 layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
1015 load.setLayoutAttr(requiredAnchorLayoutAttr);
1016 }
1017
1018 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1019 // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
1020 // layout for mask.
1021 if (chunkSize > 1) {
1022 if (layoutKind == xegpu::LayoutKind::InstData)
1023 maskLayoutAttr =
1024 xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
1025 else if (layoutKind == xegpu::LayoutKind::Lane)
1026 maskLayoutAttr =
1027 xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
1028 else
1029 assert(false &&
1030 "chunked StoreScatterOp should not be used at workgroup level");
1031 }
1032
1033 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1034 auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1035
1036 // Propagate the new layout to the tensor descriptor operand.
1037 if (isa<xegpu::TensorDescType>(load.getSourceType()))
1038 propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
1039 // Propagate the new layout to the mask and optional offset operand.
1040 propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
1041 if (load.getOffsets())
1042 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1043}
1044
1045/// Propagate the layout of the descriptor to the vector offset operand in
1046/// CreateDescOp.
1047void LayoutInfoPropagation::visitCreateDescOp(
1048 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
1049 ArrayRef<const LayoutInfoLattice *> results) {
1050 LayoutInfo descLayout = results[0]->getValue();
1051 // Need the layout of the descriptor to propagate to the operands.
1052 if (!descLayout.isAssigned())
1053 return;
1054 const uArch *uArch = getUArch(getChipStr(createDesc).value_or(""));
1055 if (!uArch)
1056 return;
1057 // For offset operand propagate 1D default layout.
1058 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
1059 uArch->getSubgroupSize());
1060 propagateIfChanged(operands[1], operands[1]->meet(layout));
1061}
1062
1063/// Set the layout for the value, tensor descriptor, offset and mask operands in
1064/// the StoreScatterOp.
1065void LayoutInfoPropagation::visitStoreScatterOp(
1066 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1067 ArrayRef<const LayoutInfoLattice *> results) {
1068
1069 xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
1070 xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
1071 const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
1072 if (!uArch)
1073 return;
1074 auto subgroupSize = uArch->getSubgroupSize();
1075 VectorType srcVecTy = storeScatter.getValueType();
1076 int chunkSize = storeScatter.getChunkSize().value_or(1);
1077
1078 if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
1079 requiredAnchorLayoutAttr = anchorLayoutAttr;
1080 } else {
1081 if (!srcVecTy) {
1082 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
1083 return;
1084 }
1085 requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
1086 layoutKind, srcVecTy, chunkSize, uArch);
1087 storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
1088 }
1089
1090 LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
1091 auto maskLayoutAttr = requiredAnchorLayoutAttr;
1092 // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
1093 // layout for mask.
1094 if (chunkSize > 1) {
1095 if (layoutKind == xegpu::LayoutKind::InstData)
1096 maskLayoutAttr =
1097 xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
1098 else if (layoutKind == xegpu::LayoutKind::Lane)
1099 maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
1100 {subgroupSize}, {1});
1101 else
1102 assert(false &&
1103 "chunked StoreScatterOp should not be used at workgroup level");
1104 }
1105
1106 LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
1107
1108 // Propagate the payload operand layout
1109 propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
1110 // Propagate the destination (if tdesc) operand layout
1111 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1112 propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
1113 // Propagate the new layout to the mask and optional offset operand.
1114 propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
1115 if (storeScatter.getOffsets())
1116 propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
1117}
1118
1119void LayoutInfoPropagation::visitLoadMatrixOp(
1120 xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
1121 ArrayRef<const LayoutInfoLattice *> results) {
1122
1123 LayoutInfo resLayoutInfo = results[0]->getValue();
1124 if (!resLayoutInfo.isAssigned())
1125 return;
1126
1127 auto consumerLayoutAttr =
1128 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
1129
1130 xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
1131
1132 // only need to set anchor layout, no need to porpagate to memdesc and
1133 // offset
1134 if (!hasParamsOfLayoutKind(anchorLayout)) {
1135 VectorType resVecTy =
1136 llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
1137 assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
1138 const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
1139 if (!uArch)
1140 return;
1141 auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
1142 layoutKind, resVecTy, consumerLayoutAttr, uArch);
1143 loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
1144 }
1145}
1146
1147// Store matrix is a flavor of scattered store for 2D shapes.
1148void LayoutInfoPropagation::visitStoreMatrixOp(
1149 xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
1150 ArrayRef<const LayoutInfoLattice *> results) {
1151 xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
1152 LayoutInfo layout;
1153 if (hasParamsOfLayoutKind(anchorLayout)) {
1154 layout = LayoutInfo(anchorLayout);
1155 } else {
1156 VectorType srcVecTy =
1157 llvm::cast<VectorType>(storeMatrix.getData().getType());
1158 assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
1159 const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
1160 if (!uArch)
1161 return;
1162 auto requiredAnchorLayoutAttr =
1163 xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
1164 storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
1165 layout = LayoutInfo(requiredAnchorLayoutAttr);
1166 }
1167
1168 propagateIfChanged(operands[0], operands[0]->meet(layout));
1169}
1170
1171 namespace {
1172 //===----------------------------------------------------------------------===//
1173 // RunLayoutInfoPropagation
1174 //===----------------------------------------------------------------------===//
1175 
1176 /// Driver class for running the LayoutInfoPropagation analysis.
/// The constructor loads the baseline dataflow analyses plus the
/// LayoutInfoPropagation analysis into the solver and eagerly runs it on
/// `op`; the solver's LogicalResult is deliberately discarded, so failures
/// surface later as unassigned layouts.
1177 class RunLayoutInfoPropagation {
1178 public:
1179   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
1180 
1181   RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
1182       : target(op) {
1183     SymbolTableCollection symbolTable;
1184     loadBaselineAnalyses(solver);
1185     solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
       // Intentionally ignore the result; see class comment above.
1186     (void)solver.initializeAndRun(op);
1187   }
1188 
       /// Look up the layout computed for `val`; returns a default
       /// (unassigned) LayoutInfo when the solver has no state for it.
1189   LayoutInfo getLayoutInfo(Value val);
1190 
       /// Dump the analysis result per function (debugging aid).
1191   void printAnalysisResult(llvm::raw_ostream &os);
1192 
1193 private:
       // Dataflow solver that owns the per-value lattice states.
1194   DataFlowSolver solver;
       // Root operation the analysis was run on.
1195   const Operation *target;
1196 };
1197 } // namespace
1198
1199LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1200 auto *state = solver.lookupState<LayoutInfoLattice>(val);
1201 if (!state)
1202 return {};
1203 return state->getValue();
1204}
1205
1206// Print the analysis result for debugging purposes.
1207void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1208 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1209 os << "function: " << funcOp.getName() << ":\n";
1210 // Function arguments
1211 for (BlockArgument arg : funcOp.getArguments()) {
1212 LayoutInfo layout = getLayoutInfo(arg);
1213 os << "argument: " << arg << "\n";
1214 os << "layout : ";
1215 layout.print(os);
1216 os << "\n";
1217 }
1218 // Function ops
1219 funcOp.walk([&](Operation *op) {
1220 // Skip ops that do not have results
1221 if (op->getResults().empty())
1222 return;
1223 os << "op : ";
1224 // For control-flow ops, print the op name only.
1225 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1226 os << op->getName();
1227 else
1228 op->print(os);
1229 os << "\n";
1230 // Print the layout for each result.
1231 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1232 LayoutInfo layout = getLayoutInfo(r);
1233 os << "layout for result #" << i << ": ";
1234 layout.print(os);
1235 os << "\n";
1236 }
1237 });
1238 };
1239
1240 SmallVector<FunctionOpInterface> funcOps;
1241 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1242 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1243 funcOps.push_back(funcOp);
1244
1245 // Collect all GpuFuncOps in the module.
1246 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1247 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1248 funcOps.push_back(gpuFuncOp);
1249 }
1250 }
1251 // Print the analysis result for each function.
1252 for (FunctionOpInterface funcOp : funcOps)
1253 printFunctionResult(funcOp);
1254}
1255
1256namespace {
1257
1258//===----------------------------------------------------------------------===//
1259// ResolveLayoutConflicts
1260//===----------------------------------------------------------------------===//
1261
1262/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
1263/// function tries to find the defining CreateNdDescOp recursively accross
1264/// control-flow boundaries.
1265static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
1266 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1267 auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
1268 if (definingOp)
1269 return definingOp;
1270 // If tdescValue is an argument, try to get the tied init value from the
1271 // parent loop-like op.
1272 if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
1273 auto *parentOp = arg.getOwner()->getParentOp();
1274 if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
1275 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
1276 if (tiedInit)
1277 return getDefiningCreateNdDescOp(tiedInit->get());
1278 }
1279 }
1280 // If not found, return null.
1281 return nullptr;
1282}
1283
/// Walks `parentOp` and inserts layout conversions (or duplicated tensor
/// descriptor creation ops) wherever a producer's layout disagrees with the
/// layout its consumer expects.
1284 struct ResolveLayoutConflicts {
1285   ResolveLayoutConflicts(Operation *parentOp)
1286       : parentOp(parentOp), builder(parentOp->getContext()) {}
       /// Run the resolution over all ops nested under `parentOp`.
1287   LogicalResult run();
1288 
1289 private:
       // Root op whose nested operations are scanned for conflicts.
1290   Operation *parentOp;
       // Builder used to materialize conversion / descriptor ops.
1291   OpBuilder builder;
       /// Resolve a layout conflict at a tensor descriptor operand.
1292   LogicalResult resolveTensorDescConsumer(OpOperand &operand);
       /// Resolve a layout conflict at a vector operand.
1293   LogicalResult resolveVectorConsumer(OpOperand &operand);
1294 };
1295
1296} // namespace
1297
1298LogicalResult ResolveLayoutConflicts::run() {
1299 // Scan all operations in the parent op and resolve layout conflicts at
1300 // tensor descriptor and vector use points.
1301 auto r = parentOp->walk([&](Operation *op) -> WalkResult {
1302 for (OpOperand &operand : op->getOpOperands()) {
1303 // Handle conflicts in tensor descriptor operands.
1304 Type operandType = operand.get().getType();
1305 if (isa<xegpu::AnchorLayoutInterface>(op) &&
1306 isa<xegpu::TensorDescType>(operandType)) {
1307 auto res = resolveTensorDescConsumer(operand);
1308 if (failed(res)) {
1309 DBGS() << "Failed to resolve tensor descriptor consumer: " << *op
1310 << "\n";
1311 return WalkResult::interrupt();
1312 }
1313 }
1314 // Handle conflicts in vector operands.
1315 if (isa<VectorType>(operandType)) {
1316 auto res = resolveVectorConsumer(operand);
1317 if (failed(res)) {
1318 DBGS() << "Failed to resolve vector consumer: " << *op << "\n";
1319 return WalkResult::interrupt();
1320 }
1321 }
1322 }
1323 return WalkResult::advance();
1324 });
1325
1326 return r.wasInterrupted() ? failure() : success();
1327}
1328
1329LogicalResult
1330ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
1331 Value vectorValue = operand.get();
1332 Operation *consumerOp = operand.getOwner();
1333 // Get the current layout of the vector value.
1334 auto producerLayout = xegpu::getDistributeLayoutAttr(vectorValue);
1335 if (!producerLayout) {
1336 if (auto vectorTy = dyn_cast<VectorType>(vectorValue.getType());
1337 vectorTy && vectorTy.getRank() > 1)
1338 consumerOp->emitWarning("Expected layout for non-1D vectors.");
1339 return success(); // uniform non-tensor-data vector does not require layout
1340 }
1341 // Get the consumer expected layout at this operand.
1342 auto consumerLayout = xegpu::getConsumerLayoutAt(operand);
1343 if (!consumerLayout)
1344 return consumerOp->emitError(
1345 "No consumer layout found for vector operand.");
1346
1347 // If layouts are same, no conflict exists, return success.
1348 if (consumerLayout.isEqualTo(producerLayout))
1349 return success();
1350
1351 // Insert a convert_layout op to resolve the conflict.
1352 builder.setInsertionPointAfterValue(vectorValue);
1353 auto convertOp = xegpu::ConvertLayoutOp::create(
1354 builder, consumerOp->getLoc(), vectorValue.getType(), vectorValue,
1355 producerLayout, consumerLayout);
1356
1357 // Update the operand to use the converted value.
1358 operand.set(convertOp.getResult());
1359 return success();
1360}
1361
1362LogicalResult
1363ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
1364 Operation *consumerOp = operand.getOwner();
1365 Value tdescValue = operand.get();
1366 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
1367 auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
1368 assert(anchorOp && currTDescType &&
1369 "Expected anchor layout op and tensor descriptor consumer.");
1370 // TODO: Scattered tensor desc is not supported for now.
1371 if (currTDescType.isScattered()) {
1372 DBGS() << "Scattered tensor descriptor not supported: " << tdescValue
1373 << "\n";
1374 return failure();
1375 }
1376 Attribute currLayout = currTDescType.getLayout();
1377 Attribute expectedLayout = anchorOp.getAnchorLayout();
1378 // A conflict exists in tensor descriptor operand if tensor descriptor's
1379 // layout is different from the anchor layout expected by the consumer.
1380 if (expectedLayout && currLayout && expectedLayout != currLayout) {
1381 // Try to get the defining CreateNdDescOp of the tensor descriptor.
1382 auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
1383 if (!conflictingCreateNdOp) {
1384 DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
1385 << tdescValue << "\n";
1386 return failure();
1387 }
1388 // Duplicate the CreateNdDescOp with the expected layout.
1389 builder.setInsertionPointAfter(conflictingCreateNdOp);
1390 auto newTensorDescType = xegpu::TensorDescType::get(
1391 conflictingCreateNdOp.getContext(), currTDescType.getShape(),
1392 currTDescType.getElementType(), currTDescType.getEncoding(),
1393 expectedLayout);
1394 xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
1395 builder, consumerOp->getLoc(), newTensorDescType,
1396 conflictingCreateNdOp->getOperands(),
1397 conflictingCreateNdOp->getAttrs());
1398 // Replace the tensor descriptor operand in the consumer op with the new
1399 // tensor descriptor.
1400 consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
1401 }
1402 return success();
1403}
1404
/// Callback type used to query the layout assigned to a value.
1405 using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1406 /// Update an operation with the layout of its results. If the result type is
1407 /// a vector type, a temporary layout attribute is added to the operation. If
1408 /// the result type is a tensor descriptor type, the type is updated with the
1409 /// layout attribute. The users of the result are also updated with the layout
1410 /// attribute.
1411 static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1412                               GetLayoutFnTy getLayoutOfValue) {
1413   // Region ops (like scf.for) are already handled by the
1414   // updateControlFlowOps.
1415   if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1416     return success();
1417 
1418   // Iterate over all the results.
1419   for (OpResult result : op->getResults()) {
1420     Type resultType = result.getType();
1421     // Layouts are needed only for vector and tensor descriptor types.
1422     if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1423       continue;
1424     // If the result has no layout but has users, emit a warning and continue.
1425     xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1426     if (!layout && result.getNumUses() > 0) {
1427       op->emitWarning("op has users but no layout assigned for its result");
1428       continue;
1429     }
1430     // If the result is a tensor descriptor type, update the tensor desc type
1431     // with layout.
1432     if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1433       auto typeWithLayout = xegpu::TensorDescType::get(
1434           tensorDescTy.getContext(), tensorDescTy.getShape(),
1435           tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1436       result.setType(typeWithLayout);
1437       continue;
1438     }
1439     // If the result is a vector type, add a temporary layout attribute to the
1440     // op.
      // NOTE(review): original line 1441 is missing from this listing — per
      // the doc comment above it presumably attaches the temporary layout to
      // the vector result; confirm against upstream.
1442   }
1443   return success();
1444 }
1445
1446 /// Region ops like scf.for need special handling because they have blocks
1447 /// inside. If the blocks have tensor descriptor type as block arguments,
1448 /// thier types must be updated. Also region op can have results that may not
1449 /// have any users (e.g. A and B tiles). They are not assigned a layout by
1450 /// layout analysis because they have no users. However inside the region op
1451 /// corresponding block arguments for these results do have layouts.
1452 /// Therefore, in this case we still need to update the result types with the
1453 /// layout attribute. This function function updates the internal block
1454 /// arguments and the result types of the region op with the assigned layouts.
1455 /// clang-format off
1456 /// Example: scf.for ... iter_args(...) -> (out types) {
1457 ///   ^bb0(block types):
1458 ///     ...
1459 ///     scf.yield ... : (yield types)
1460 /// }
1461 /// clang-format on
1462 /// In this example, at scf.yield, control-flow can transfer to two successor
1463 /// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1464 /// itself (yield the results). So we update both the block arguments of the
1465 /// successor region (i.e. block types) and the result types of the scf.for op
1466 /// (i.e. out types). Note that yield types are updated by respective
1467 /// producers inside bb0.
1468 static LogicalResult
     // NOTE(review): original line 1469 (the function-name line of the
     // signature, taking an OpBuilder as first parameter) is missing from
     // this listing; confirm against upstream.
1470     mlir::RegionBranchTerminatorOpInterface terminator,
1471     GetLayoutFnTy getLayoutOfValue) {
1472   // Only process if the terminator is inside a region branch op.
1473   auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1474   if (!branchOp)
1475     return success();
1476 
     // NOTE(review): original line 1477 (the declaration of `mapping`, filled
     // by getSuccessorOperandInputMapping below) is missing from this
     // listing; confirm against upstream.
1478   branchOp.getSuccessorOperandInputMapping(mapping,
1479                                            RegionBranchPoint(terminator));
     // Walk each forwarded terminator operand together with the successor
     // inputs (block args or op results) it flows into.
1480   for (const auto &[successorOperand, successorInputs] : mapping) {
1481     for (Value successorInput : successorInputs) {
1482       Type inputType = successorInput.getType();
1483       // We only need to operate on tensor descriptor or vector types.
1484       if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1485         continue;
1486       xegpu::DistributeLayoutAttr successorInputLayout =
1487           getLayoutOfValue(successorInput);
1488       xegpu::DistributeLayoutAttr successorOperandLayout =
1489           getLayoutOfValue(successorOperand->get());
1490 
1491       // If either of the layouts is not assigned, we cannot proceed.
1492       if (!successorOperandLayout) {
1493         LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1494                              "branch terminator: "
1495                           << successorOperand->get() << "\n");
1496         return failure();
1497       }
1498       // We expect the layouts to match.
1499       if (successorInputLayout &&
1500           successorInputLayout != successorOperandLayout) {
1501         LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1502                              "operand forwarded as the argument: "
1503                           << successorInputLayout << " vs "
1504                           << successorOperandLayout << "\n");
1505         return failure();
1506       }
1507       // Get tensor descriptor type with the layout.
1508       if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1509         auto newTdescTy = xegpu::TensorDescType::get(
1510             tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1511             tdescTy.getEncoding(), successorOperandLayout);
1512         successorInput.setType(newTdescTy);
1513         continue;
1514       }
1515       // If the type is a vector type and this region argument is an OpResult,
1516       // set the layout attribute on the OpResult.
1517       if (auto result = dyn_cast<OpResult>(successorInput))
1518         xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1519     }
1520   }
1521   return success();
1522 }
1523
1524/// Update the function arguments and results with the layouts.
1525static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1526 mlir::FunctionOpInterface funcOp,
1527 GetLayoutFnTy getLayoutOfValue) {
1528 // Only process functions whose type is a standard MLIR FunctionType.
1529 // Functions using a different type representation (e.g. llvm.func with
1530 // LLVMFunctionType) are not targets for XeGPU layout propagation, and
1531 // calling setType(FunctionType{}) on them would corrupt their type.
1532 if (!isa<FunctionType>(funcOp.getFunctionType()))
1533 return success();
1534 SmallVector<Type> newArgTypes;
1535 // Update the function arguments.
1536 for (BlockArgument arg : funcOp.getArguments()) {
1537 Type argType = arg.getType();
1538 newArgTypes.push_back(argType);
1539 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1540 continue;
1541 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1542 if (!layout) {
1543 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1544 << " but got none.\n");
1545 return failure();
1546 }
1547 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1548 auto newTdescTy = xegpu::TensorDescType::get(
1549 tensorDescTy.getContext(), tensorDescTy.getShape(),
1550 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1551 arg.setType(newTdescTy);
1552 newArgTypes.back() = newTdescTy;
1553 }
1554 }
1555 // Update the function type with the new argument types.
1556 // NOTE: We assume that function results are not expected to have layouts.
1557 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1558 funcOp.getResultTypes()));
1559 return success();
1560}
1561
1562 namespace {
/// Pass wrapper that runs layout propagation followed by conflict
/// resolution; see runOnOperation below.
1563 struct XeGPUPropagateLayoutPass final
1564     : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1565   XeGPUPropagateLayoutPass() = default;
1566   XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
       /// Construct the pass from tablegen-generated options (layout-kind,
       /// print-only).
1567   XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1568       : XeGPUPropagateLayoutBase(std::move(options)) {}
1569   void runOnOperation() override;
1570 };
1571 
1572 } // namespace
1573
     // NOTE(review): original line 1574 (the first line of the signature,
     // `LogicalResult xegpu::propagateLayouts(...)` taking a builder and the
     // target op) is missing from this listing; confirm against upstream.
1575                                      LayoutKind layoutKind, bool printOnly) {
     // Run the dataflow analysis over the whole target op.
1576   RunLayoutInfoPropagation analysis(target, layoutKind);
1577   // Print the analysis result and exit. (for debugging purposes)
1578   if (printOnly) {
1579     auto &os = llvm::outs();
1580     analysis.printAnalysisResult(os);
1581     return success();
1582   }
1583   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
1584   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1585     LayoutInfo layout = analysis.getLayoutInfo(val);
1586     if (!layout.isAssigned())
1587       return {};
       // Anchor layouts and temporary layouts recorded on op results take
       // precedence over the analysis result.
1588     if (auto opResult = dyn_cast<OpResult>(val)) {
1589 
1590       Operation *defOp = opResult.getDefiningOp();
1591       if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1592         auto anchorLayout = anchorOp.getAnchorLayout();
1593         if (anchorLayout != nullptr)
1594           return anchorLayout;
1595       }
1596       xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1597           xegpu::getTemporaryLayout(opResult);
1598       if (requiredResLayoutAttr != nullptr)
1599         return requiredResLayoutAttr;
1600     }
1601     xegpu::DistributeLayoutAttr layoutAttr =
1602         cast<xegpu::DistributeLayoutAttr>(layout.get());
1603     if (layout.isSliceLayout())
1604       return cast<xegpu::SliceAttr>(layoutAttr);
1605 
1606     return cast<xegpu::LayoutAttr>(layoutAttr);
1607   };
1608 
     // Walk every block, visiting ops in reverse order, and write the computed
     // layouts back into the IR.
1609   Operation *op = target;
1610   auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1611     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1612       LogicalResult r = success();
       // NOTE(review): original line 1613 (presumably the
       // `TypeSwitch<Operation *>(&op)` head of this dispatch) is missing
       // from this listing; confirm against upstream.
1614           .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1615             r = updateControlFlowOps(builder, branchTermOp,
1616                                      getXeGPULayoutForValue);
1617           })
1618           .Case([&](mlir::FunctionOpInterface funcOp) {
1619             r = updateFunctionOpInterface(builder, funcOp,
1620                                           getXeGPULayoutForValue);
1621           })
1622           .Default([&](Operation *op) {
1623             r = updateOp(builder, op, getXeGPULayoutForValue);
1624           });
1625       if (failed(r)) {
1626         op.emitError("Failed to update operation with the layout.");
1627         return WalkResult::interrupt();
1628       }
1629     }
1630     return WalkResult::advance();
1631   });
1632   if (walkResult.wasInterrupted())
1633     return failure();
1634 
1635   return success();
1636 }
1637
     // NOTE(review): original line 1638 (the signature, presumably
     // `LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {`) is
     // missing from this listing; confirm against upstream.
     // Thin entry point: construct the resolver over `target` and run it.
1639   ResolveLayoutConflicts resolver(target);
1640   return resolver.run();
1641 }
1642
1643void XeGPUPropagateLayoutPass::runOnOperation() {
1644 xegpu::LayoutKind layoutKind;
1645 if (this->layoutKind == "lane") {
1646 layoutKind = xegpu::LayoutKind::Lane;
1647 } else if (this->layoutKind == "inst") {
1648 layoutKind = xegpu::LayoutKind::InstData;
1649 } else if (this->layoutKind == "subgroup") {
1650 layoutKind = xegpu::LayoutKind::Subgroup;
1651 } else {
1652 getOperation()->emitError("Unsupported layout kind option: " +
1653 this->layoutKind);
1654 signalPassFailure();
1655 return;
1656 }
1657 OpBuilder builder(&getContext());
1658 if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
1659 this->printOnly))) {
1660 signalPassFailure();
1661 return;
1662 }
1663 // Resolve layout conflicts if any.
1664 if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
1665 signalPassFailure();
1666 return;
1667 }
1668}
return success()
#define DBGS()
Definition Hoisting.cpp:32
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
lhs
b getContext())
auto load
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static SmallVector< LayoutRepresentation > getValidLayouts(ArrayRef< int64_t > wgShape, ArrayRef< int64_t > instData, int64_t sgCount)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
function_ref< xegpu::DistributeLayoutAttr(Value)> GetLayoutFnTy
FailureOr< int64_t > getNumSg(Operation *op, const int sgSize)
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
The general data-flow analysis solver.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:391
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
void print(raw_ostream &os, const OpPrintingFlags &flags={})
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:805
result_range getResults()
Definition Operation.h:423
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
Definition Utils.h:29
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
const uArch * getUArch(llvm::StringRef archName)
DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a shape cast operation given the result layout attribute,...
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, bool printOnly=false)
SliceAttr setupMultiReductionResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, DistributeLayoutAttr consumerLayout, SmallVector< int64_t > reductionDims, const uArch::uArch *uArch)
Sets up layout for reduction operations by creating a SliceAttr for the result.
DistributeLayoutAttr inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for an insert strided slice operation given the result layout attr...
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
std::optional< std::tuple< DistributeLayoutAttr, DistributeLayoutAttr, DistributeLayoutAttr > > setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy, VectorType cdTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch, int numSg)
Sets up the anchor layouts for a dpas operands (A, B, and C/D).
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
DistributeLayoutAttr setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for load matrix operation.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout, ArrayRef< int64_t > resShape, ArrayRef< int64_t > srcShape)
Infers the source layout attribute for a broadcast operation given the result layout attribute,...
DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, const uArch::uArch *uArch)
Sets up the anchor layout for a store scatter operation.
DistributeLayoutAttr setupBitCastResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Setup the result layout attribute for a bitcast operation based on element type bitwidths.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
LogicalResult resolveLayoutConflicts(Operation *target)
DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout, int resElemTyBitWidth, int srcElemTyBitWidth)
Infers the source layout attribute for a bitcast operation given the result layout attribute,...
DistributeLayoutAttr setupInsertStridedSliceResultLayout(LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the result layout for an insert strided slice operation.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand)
Gets the expected layout for a given consumer operand.
DistributeLayoutAttr inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout, SmallVector< int64_t > reduceDims)
Infers the source layout attribute for a reduction operation given the result layout attribute and re...
DistributeLayoutAttr setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, int chunkSize, DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch)
Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy, const uArch::uArch *uArch)
Sets up the anchor layout for a store matrix operation.
Include the generated interface declarations.
DenseMap< OpOperand *, SmallVector< Value > > RegionBranchSuccessorMapping
A mapping from successor operands to successor inputs.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
virtual int getSubgroupSize() const =0
const Instruction * getInstruction(InstructionKind instKind) const
Definition uArchBase.h:163