//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;

namespace {

enum class LayoutKind { Lane, InstData };

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//

/// Helper class for tracking the analysis state of an MLIR value. For layout
/// propagation, the analysis state is simply the distribution layout of each
/// value. The distribution layout information is encapsulated in the
/// xegpu::DistributeLayoutAttr class, which can hold any kind of distribution
/// layout that the XeGPU dialect supports. The purpose of this analysis is to
/// propagate a unique distribution layout to each value in the program,
/// starting from a set of anchor operations (like DPAS, StoreNd, etc.). Note
/// that the analysis reaches a fixed point once every value has been assigned
/// some layout; it never modifies a layout that is already assigned.
///
/// Given this, LayoutInfo satisfies the following properties:
/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
///    assigned`.
/// 2) Two LayoutInfo values are equal if they are both assigned or both not
///    assigned. The concrete value of the assigned state does not matter.
/// 3) The meet operator works as follows (see the sketch below):
///    - If the current state is assigned, return the current state (a unique
///      layout is already assigned; don't change it).
///    - Otherwise, return the other state.
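///
/// A minimal sketch of the resulting semantics (with a hypothetical layout
/// attribute `anyLayout`):
///   LayoutInfo unassigned;               // not assigned
///   LayoutInfo assigned(anyLayout);      // assigned
///   LayoutInfo::meet(assigned, unassigned) == assigned;   // keeps existing
///   LayoutInfo::meet(unassigned, assigned) == assigned;   // adopts incoming
/// Equality only compares the assigned/not-assigned state, so any two
/// assigned values compare equal regardless of the concrete layout.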
81
struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const { return storage != nullptr; }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  SmallVector<int> getLaneLayout() const;

  SmallVector<int> getLaneData() const;

  SmallVector<int> getInstData() const;

  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const {
    if (!isAssigned())
      return -1;
    return storage.getRank();
  }

  Attribute get() { return storage; }
};

SmallVector<int> LayoutInfo::getLaneLayout() const {
  if (!isAssigned())
    return {};
  assert(storage.getEffectiveLaneLayoutAsInt().size() &&
         "Expected lane layout to be assigned");
  return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getLaneData() const {
  if (!isAssigned())
    return {};
  assert(storage.getEffectiveLaneDataAsInt().size() &&
         "Expected lane data to be assigned");
  return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getInstData() const {
  if (!isAssigned())
    return {};
  return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << storage;
  } else {
    os << "Not assigned.";
  }
}

LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

/// Since this is a backward analysis, join method is not used.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}

/// Construct a new layout with the transposed inst_data or lane_layout,
/// lane_data.
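/// For example (hypothetical values): a layout with lane_layout = [1, 16] and
/// lane_data = [1, 2] transposed by permutation [1, 0] becomes
/// lane_layout = [16, 1], lane_data = [2, 1]; an inst_data of [8, 16] becomes
/// [16, 8] in the same way.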
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  // Check if the permutation is valid.
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return {};
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
  }
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  return LayoutInfo(layoutAttr);
}

//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};

/// Helper Functions to get default layouts. A `default layout` is a layout
/// that is assigned to a value when the layout is not fixed by some anchor
/// operation (like DPAS).

/// Helper Function to get the default layout for uniform values like constants.
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
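/// For example, with a subgroup size of 16, a 1D value gets
/// lane_layout = [16], lane_data = [1], and a 2D value gets
/// lane_layout = [1, 16], lane_data = [1, 1]; i.e. the lanes are spread along
/// the innermost dimension and each lane owns a single element.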
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank,
                                           const xegpu::uArch::uArch *uArch) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1) {
    return LayoutInfo(
        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
  }
  return LayoutInfo(
      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}

static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank, int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1) {
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  }
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return default layout for 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered) {
    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                             {uArch->getSubgroupSize(), 1},
                                             {1, packingFactor}));
  }
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {1, uArch->getSubgroupSize()},
                                           {1, packingFactor}));
}

/// Helper to get the default layout for a tensor descriptor type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D vector.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return default layout for 1D vector.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  int subgroupSize = uArch->getSubgroupSize();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  if (isScattered) {
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
  }

  return LayoutInfo(xegpu::LayoutAttr::get(
      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
}

/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
///   `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
///   `packedSizeInBitsForDpasB`
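/// As an illustration (assuming a 16-lane subgroup and a 32-bit packing size
/// for B): an f16 B operand gets lane_layout = [1, 16] with
/// lane_data = [2, 1], i.e. two consecutive rows are packed per lane (the
/// VNNI format), while the A and C operands fall back to the default layout.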
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                const xegpu::uArch::uArch *uArch,
                                unsigned packingSize) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
  // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
  // must have the VNNI format.
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}

//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for the operands of DPAS,
/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
/// of this analysis is to propagate those known layouts to all their producers
/// and (other) consumers.
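///
/// For example (illustrative IR, types abbreviated):
///   %a = xegpu.load_nd %ta ... -> vector<8x16xf16>
///   %b = xegpu.load_nd %tb ... -> vector<16x16xf16>
///   %c = xegpu.dpas %a, %b ... -> vector<8x16xf32>
///   xegpu.store_nd %c, %tc ...
/// The DPAS and store anchors fix the layouts of %a, %b, and %c, and the
/// backward walk then propagates those layouts onto the tensor descriptors
/// %ta, %tb, and %tc produced upstream.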
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  LayoutKind layoutKind;
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
} // namespace

LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      // All other ops.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // If the operand type is not a vector or tensor descriptor, skip
            // it.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            // Propagate the result layout to the operand.
            meet(operandInfo, *resultInfo);
          }
        }
      });

  return success();
}

bool LayoutInfoPropagation::hasParamsOfLayoutKind(
    xegpu::DistributeLayoutAttr anchorLayout) {
  if (anchorLayout == nullptr) {
    return false;
  }
  if (layoutKind == LayoutKind::InstData) {
    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
  } else if (layoutKind == LayoutKind::Lane) {
    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
             anchorLayout.getEffectiveLaneDataAsInt().empty());
  }
  return false;
}

void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo prefetchLayout;
  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    prefetchLayout = LayoutInfo(anchorLayout);
  } else {
    // Here we assign the default layout to the tensor descriptor operand of
    // prefetch.
    auto tdescTy = prefetch.getTensorDescType();

    auto uArch = getUArch(getChipStr(prefetch).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));

    auto blockWHC =
        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
    if (!blockWHC)
      prefetch.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (tdescTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        prefetch.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }

    if (layoutKind == LayoutKind::InstData)
      prefetchLayout =
          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
    else
      prefetchLayout = getDefaultSIMTLayoutInfo(
          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());

    prefetch.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
  }
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}

void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
  // Given that the result is 1D, the layout of the operand should be 2D with
  // default layout.
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
      reduction->getContext(), 2, uArch->getSubgroupSize());
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // Accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}

void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Only consider vector to vector broadcasts for now.
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  // Skip layout propagation for a non-vector source operand.
  if (!sourceTy)
    return;

  // Handling the broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
  if (sourceTy.getRank() != resultTy.getRank()) {
    auto sourceDims = sourceTy.getShape();
    auto resultDims = resultTy.getShape();
    SmallVector<int64_t> bcastDims;
    auto dimDiff = resultTy.getRank() - sourceTy.getRank();
    // Add the missing leading dims.
    for (int i = 0; i < dimDiff; i++)
      bcastDims.push_back(i);

    // For the remaining dims in resultTy: if the sourceTy dim is 1 (and the
    // corresponding result dim is not), it is a broadcasted dim.
    for (size_t i = 0; i < sourceDims.size(); i++)
      if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
        bcastDims.push_back(i + dimDiff);

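    // As a concrete illustration (hypothetical shapes): broadcasting
    // vector<16xf32> to vector<8x16xf32> gives dimDiff = 1 and
    // bcastDims = [0], so the source receives a slice of the result layout
    // over dim 0; broadcasting vector<1x16xf32> to vector<8x16xf32> has equal
    // ranks and is handled below via the broadcasted unit dims instead.
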
    // Create a slice layout for the source.
    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
        broadcast->getContext(),
        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
        DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));

    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
    return;
  }

  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
  resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
                     .setUnitDimData(broadcastUnitDims);
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  // Shape cast layout propagation only supports 1D -> 2D shape casts.
  // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
  // unit dimensions and non-unit dims match.
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}

/// Propagate the layout of the result tensor to the source tensor descriptor
/// in UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);

  } else {

    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();

    auto uArch = getUArch(getChipStr(dpas).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));

    const unsigned dataALen = aTy.getShape().front();
    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
    const int maxALen =
        xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
    if (maxALen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");

    const unsigned dataBLen = bTy.getShape().back();
    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());

    const int maxBLen =
        xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));

    if (maxBLen == -1)
      dpas.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    SmallVector<int> instDataA = {maxALen, subgroupSize};
    SmallVector<int> instDataB = {subgroupSize, maxBLen};

    if (layoutKind == LayoutKind::InstData) {
      dpasALayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
      dpasBLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
    } else {
      dpasALayout = getSIMTLayoutInfoForDPASOperand(
          aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
      dpasBLayout = getSIMTLayoutInfoForDPASOperand(
          bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
    }

    if (operands.size() > 2) {
      VectorType cTy = dpas.getAccType();
      if (layoutKind == LayoutKind::InstData) {
        const unsigned dataCLen = bTy.getShape().back();
        auto supportedCLen =
            uArchInstruction->getSupportedN(bTy.getElementType());
        const int maxCLen = xegpu::getLargestDivisor(
            dataCLen, ArrayRef<unsigned>(supportedCLen));
        if (maxCLen == -1)
          dpas.emitWarning(
              "No suitable instruction multiple found for the given shape.");
        SmallVector<int> instDataC = {maxALen, maxCLen};
        dpasCDLayout =
            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
      } else
        dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
            cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());

      dpas.setLayoutCdAttr(
          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
    }
    dpas.setLayoutAAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
    dpas.setLayoutBAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
  }

  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2) {
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
  }
}

/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo storeLayout;
  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    storeLayout = LayoutInfo(anchorLayout);
  } else {
    auto uArch = getUArch(getChipStr(store).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
    VectorType dataTy = store.getValueType();
    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
        store.getValueType().getElementType());
    if (!blockWHC)
      store.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (dataTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        store.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }

    if (layoutKind == LayoutKind::InstData)
      storeLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
    else
      storeLayout =
          getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
                                   uArchInstruction->getPackedFormatBitSize());
    store.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
  }
  // Propagate the layout to the value operand.
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}

/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo loadLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
  } else {

    LayoutInfo valueLayout = results[0]->getValue();
    // Need the layout of the value to propagate to the tensor descriptor.
    if (!valueLayout.isAssigned())
      return;
    loadLayout = valueLayout;
    // LoadNdOp has the transpose effect. However, at the stage of this
    // analysis this effect is not expected and should be abstracted away.
    // Emit a warning.
    if (auto transpose = load.getTranspose()) {
      load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                       "LayoutInfoPropagation stage.");
      loadLayout = valueLayout.transpose(transpose.value());
    }
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}

/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}

/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit width of the source and result types.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // If the element bit widths are the same, then the layout does not change.
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  // Check if the result layout is valid, i.e. the result vector can be
  // distributed.
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  // The bitcast is `narrowing` if the input element type bit width is larger
  // than the output element type bit width, e.g. f32 -> f16 is a narrowing
  // bitcast.
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
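  // As a worked example (hypothetical types): bitcasting vector<8x8xi32> to
  // vector<8x16xf16> is a narrowing bitcast with bitCastRatio = 2. If each
  // lane owns two f16 results along the innermost dimension (32 bits), those
  // bits come from exactly one i32 source element, so the source lane_data
  // becomes [1, 1] and no cross-lane exchange is needed.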
  SmallVector<int> sourceLaneLayout =
      resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
  SmallVector<int> outData = resultLayout.getLaneData();

  // TODO: Currently we assume that a bitcast does not require cross-lane
  // communication. So each lane must own the required number of elements to
  // perform the bitcast locally without cross-lane communication.
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  // Check if each lane owns a single element in all dimensions except the
  // innermost dimension.
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than the innermost dimension.");
    return;
  }
  // Decide lane data based on whether the bitcast is narrowing or widening.
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);

  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}

/// Propagate the layout of the result to the tensor descriptor, mask and
/// offset operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo loadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
    maskLayout = loadLayout;
  } else {

    // The layout is strictly determined by the payload type.
    VectorType payloadTy = load.getValueType();
    if (!payloadTy) {
      load.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    auto uArch = getUArch(getChipStr(load).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();
    SmallVector<int> instData{subgroupSize};
    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
      instData.push_back(chunkSize);
    else if (auto srcTdescTy =
                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
      if (srcTdescTy.getChunkSizeAsInt() > 1)
        instData.push_back(srcTdescTy.getChunkSizeAsInt());
    }

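    // For example (assuming a subgroup size of 16): a gather with
    // chunk_size = 4 ends up with inst_data = [16, 4], while a plain
    // per-lane gather keeps inst_data = [16].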
    if (layoutKind == LayoutKind::InstData)
      loadLayout =
          LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
    else
      loadLayout = getDefaultSIMTLayoutInfo(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
          /*scattered*/ true);

    // Mask operand should have 1D default layout.
    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);

    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  // Propagate the new layout to the tensor descriptor operand.
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
  // Propagate the new layout to the mask and optional offset operand.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}

/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  auto uArch = getUArch(getChipStr(createDesc).value_or(""));
  // For offset operand propagate 1D default layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               uArch->getSubgroupSize());
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, offset and mask operands in
/// the StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo payloadLayout;
  LayoutInfo maskLayout;
  xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    payloadLayout = LayoutInfo(anchorLayout);
    maskLayout = payloadLayout;
  } else {
    // Currently, for 2D StoreScatterOp we expect that the height dimension of
    // the tensor descriptor is equal to the subgroup size. This is ensured by
    // the op verifier.
    VectorType payloadTy = storeScatter.getValueType();
    if (!payloadTy) {
      storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }

    auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
    const int subgroupSize = uArch->getSubgroupSize();

    if (layoutKind == LayoutKind::InstData) {
      SmallVector<int> instData{subgroupSize};
      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
          chunkSize > 1)
        instData.push_back(chunkSize);
      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
                   storeScatter.getDestType())) {
        if (dstTdescTy.getChunkSizeAsInt() > 1)
          instData.push_back(dstTdescTy.getChunkSizeAsInt());
      }
      payloadLayout = LayoutInfo(
          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
    } else {
      auto payloadShape = payloadTy.getShape();
      if (payloadShape.size() > 1)
        assert(payloadShape[0] == subgroupSize &&
               "Expected the first dimension of 2D tensor descriptor to be "
               "equal to subgroup size.");
      payloadLayout = getDefaultSIMTLayoutInfo(
          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
          /*scattered=*/true);
    }

    maskLayout =
        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);

    storeScatter.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
  }
  // Propagate the payload operand layout.
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  // Propagate the destination (if tdesc) operand layout.
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  // Propagate the new layout to the mask and optional offset operand.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}

namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
} // namespace

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}

// Print the analysis result for debugging purposes.
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Function ops.
    funcOp.walk([&](Operation *op) {
      // Skip ops that do not have results.
      if (op->getResults().empty())
        return;
      os << "op    : ";
      // For control-flow ops, print the op name only.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        os << op->getName();
      else
        op->print(os);
      os << "\n";
      // Print the layout for each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);

    // Collect all GpuFuncOps in the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  // Print the analysis result for each function.
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}

using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
/// the result type is a tensor descriptor type, the type is updated with the
/// layout attribute. The users of the result are also updated with the layout
/// attribute.
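/// For instance (illustrative): a vector<8x16xf32> result might receive a
/// discardable attribute holding
/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, whereas a
/// !xegpu.tensor_desc<8x16xf32> result is rewritten in place to carry the
/// layout inside the type, e.g. !xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>.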
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are already handled by the
  // updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor types.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has no layout but has users, emit a warning and continue.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // If the result is a tensor descriptor type, update the tensor desc type
    // with layout.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // If the result is a vector type, add a temporary layout attribute to the
    // op.
    xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout=*/true);
  }
  return success();
}

/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments,
/// their types must be updated. Also, a region op can have results that may
/// not have any users (e.g. A and B tiles). They are not assigned a layout by
/// the layout analysis because they have no users. However, inside the region
/// op the corresponding block arguments for these results do have layouts.
/// Therefore, in this case we still need to update the result types with the
/// layout attribute. This function updates the internal block arguments and
/// the result types of the region op with the assigned layouts.
/// clang-format off
/// Example: scf.for ... iter_args(...) -> (out types) {
///   ^bb0(block types):
///     ...
///     scf.yield ... : (yield types)
/// }
/// clang-format on
/// In this example, at scf.yield, control flow can transfer to two successor
/// regions. One is ^bb0 (the for-loop body) and the other is the scf.for op
/// itself (yielding the results). So we update both the block arguments of the
/// successor region (i.e. block types) and the result types of the scf.for op
/// (i.e. out types). Note that yield types are updated by the respective
/// producers inside bb0.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  // Only process if the terminator is inside a region branch op.
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();

  llvm::SmallVector<mlir::RegionSuccessor> successors;
  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
                                              nullptr);
  terminator.getSuccessorRegions(operands, successors);

  for (mlir::RegionSuccessor &successor : successors) {
    mlir::OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);

      // If either of the layouts is not assigned, we cannot proceed.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand << "\n");
        return failure();
      }
      // We expect the layouts to match.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      // Get tensor descriptor type with the layout.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // If the type is a vector type and this region argument is an OpResult,
      // set the layout attribute on the OpResult.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}

/// Update the function arguments and results with the layouts.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
                                               mlir::FunctionOpInterface funcOp,
                                               GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function arguments.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      return failure();
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types.
  // NOTE: We assume that function results are not expected to have layouts.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}

namespace {
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};

} // namespace

void XeGPUPropagateLayoutPass::runOnOperation() {
  LayoutKind layoutKind;
  if (this->layoutKind == "lane") {
    layoutKind = LayoutKind::Lane;
  } else if (this->layoutKind == "inst") {
    layoutKind = LayoutKind::InstData;
  } else {
    getOperation()->emitError("Unsupported layout kind option: " +
                              this->layoutKind);
    signalPassFailure();
    return;
  }
  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
  // Print the analysis result and exit (for debugging purposes).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert LayoutInfo to xegpu::DistributeLayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    xegpu::DistributeLayoutAttr layoutAttr =
        cast<xegpu::DistributeLayoutAttr>(layout.get());
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layoutAttr);
    return cast<xegpu::LayoutAttr>(layoutAttr);
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
    for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
      LogicalResult r = success();
      TypeSwitch<Operation *>(&op)
          .Case<mlir::RegionBranchTerminatorOpInterface>(
              [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
                r = updateControlFlowOps(builder, branchTermOp,
                                         getXeGPULayoutForValue);
              })
          .Case<mlir::FunctionOpInterface>(
              [&](mlir::FunctionOpInterface funcOp) {
                r = updateFunctionOpInterface(builder, funcOp,
                                              getXeGPULayoutForValue);
              })
          .Default([&](Operation *op) {
            r = updateOp(builder, op, getXeGPULayoutForValue);
          });
      if (failed(r)) {
        op.emitError("Failed to update operation with the layout.");
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}