XeGPUPropagateLayout.cpp
1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/Builders.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/Value.h"
25#include "mlir/IR/Visitors.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/LogicalResult.h"
37#include "llvm/Support/raw_ostream.h"
38
40
41namespace mlir {
42namespace xegpu {
43#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
45} // namespace xegpu
46} // namespace mlir
47
48#define DEBUG_TYPE "xegpu-propagate-layout"
49#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
50
51using namespace mlir;
52using namespace mlir::dataflow;
53
54namespace {
55
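/// Which layout parameters the pass propagates: per-lane (SIMT) layout,
/// per-instruction tile shapes (inst_data), or per-subgroup distribution.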
56enum class LayoutKind { Lane, InstData, Subgroup };
57
58//===----------------------------------------------------------------------===//
59// LayoutInfo
60//===----------------------------------------------------------------------===//
61
62/// Helper class for tracking the analysis state of an MLIR value. For layout
63/// propagation, the analysis state is simply the distribution layout of
64/// each value. The distribution layout information is encapsulated in the
65/// xegpu::DistributeLayoutAttr class, which can hold any type of distribution
66/// layout that the XeGPU dialect supports. The purpose of this analysis is to
67/// propagate a unique distribution layout to each value in the program,
68/// starting from a set of anchor operations (like DPAS, StoreNd, etc.). Note
69/// that the analysis reaches a fixed point once every value has been assigned
70/// some layout; the analysis does not try to modify any already assigned
71/// layouts.
72/// Given this, LayoutInfo satisfies the following properties:
73/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
74/// assigned`.
75/// 2) Two LayoutInfo values are equal if they are both assigned or
76/// both not assigned. The concrete value of assigned state does not matter.
77/// 3) The meet operator works as follows:
78/// - If current state is assigned, return the current state. (already
79/// a unique layout is assigned. don't change it)
80/// - Otherwise, return the other state.
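///
/// Usage sketch of the meet semantics (illustrative, assuming some layout
/// attribute `l`):
///   LayoutInfo a;                 // not assigned
///   LayoutInfo b(l);              // assigned
///   LayoutInfo::meet(a, b) == b;  // takes the assigned state
///   LayoutInfo::meet(b, a) == b;  // an assigned state is never overwritten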
81
82struct LayoutInfo {
83private:
84 xegpu::DistributeLayoutAttr storage = nullptr;
85
86public:
87 LayoutInfo() = default;
88 LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
89
90 // Two lattice values are equal if they have `some` layout. The actual
91 // content of the layout does not matter.
92 bool operator==(const LayoutInfo &other) const {
93 return this->isAssigned() == other.isAssigned();
94 }
95
96 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
97
98 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
99
100 void print(raw_ostream &os) const;
101
102 bool isAssigned() const { return storage != nullptr; }
103
104 LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
105
106 SmallVector<int> getLaneLayout() const;
107
108 SmallVector<int> getLaneData() const;
109
110 SmallVector<int> getInstData() const;
111
112 SmallVector<int> getSgLayout() const;
113
114 SmallVector<int> getSgData() const;
115
116 SmallVector<int> getOrder() const;
117
118 bool isSliceLayout() const {
119 if (!isAssigned())
120 return false;
121 return isa<xegpu::SliceAttr>(storage);
122 }
123
124 int64_t getRank() const {
125 if (!isAssigned())
126 return -1;
127 return storage.getRank();
128 }
129
130 Attribute get() { return storage; }
131};
132
133SmallVector<int> LayoutInfo::getLaneLayout() const {
134 if (!isAssigned())
135 return {};
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](int64_t val) { return static_cast<int>(val); });
138}
139
140SmallVector<int> LayoutInfo::getLaneData() const {
141 if (!isAssigned())
142 return {};
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](int64_t val) { return static_cast<int>(val); });
145}
146
147SmallVector<int> LayoutInfo::getInstData() const {
148 if (!isAssigned())
149 return {};
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](int64_t val) { return static_cast<int>(val); });
152}
153
154SmallVector<int> LayoutInfo::getSgLayout() const {
155 if (!isAssigned())
156 return {};
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](int64_t val) { return static_cast<int>(val); });
159}
160
161SmallVector<int> LayoutInfo::getSgData() const {
162 if (!isAssigned())
163 return {};
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](int64_t val) { return static_cast<int>(val); });
166}
167
168SmallVector<int> LayoutInfo::getOrder() const {
169 if (!isAssigned() || !storage.getOrder())
170 return {};
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](int64_t val) { return static_cast<int>(val); });
173}
174
175void LayoutInfo::print(raw_ostream &os) const {
176 if (isAssigned()) {
177 os << storage;
178 } else {
179 os << "Not assigned.";
180 }
181}
182
183LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
184 if (!lhs.isAssigned())
185 return rhs;
186 return lhs;
187}
188
189/// Since this is a backward analysis, the join method is not used.
190LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
191 llvm_unreachable("Join should not be triggered by layout propagation.");
192}
193
194/// Construct a new layout whose lane_layout/lane_data, inst_data,
195/// sg_layout/sg_data, and order fields are permuted by `permutation`.
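///
/// Illustrative example: applying permutation [1, 0] to a layout with
/// lane_layout = [1, 16] and lane_data = [1, 2] yields lane_layout = [16, 1]
/// and lane_data = [2, 1].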
196LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
197 if (!isAssigned())
198 return {};
199 // Check if the permutation is valid.
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
204 });
205
206 if (!withinRange || hasDuplicates) {
207 assert(false && "Invalid permutation for transpose.");
208 return {};
209 }
210
211 SmallVector<int32_t> laneLayout;
212 SmallVector<int32_t> laneData;
213 SmallVector<int32_t> instData;
214 SmallVector<int32_t> sgLayout;
215 SmallVector<int32_t> sgData;
216 SmallVector<int32_t> order;
217
218 for (int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
221 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
222 }
223 if (getInstData().size())
224 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
227 sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
228 }
229 if (getOrder().size()) {
230 order.push_back(static_cast<int32_t>(getOrder()[idx]));
231 }
232 }
233 auto orderAttr = order.size()
234 ? DenseI32ArrayAttr::get(storage.getContext(), order)
235 : nullptr;
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
238 layoutAttr =
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
245 DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
246 DenseI32ArrayAttr::get(storage.getContext(), sgData),
247 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
248 /*lane_data =*/nullptr, orderAttr);
249 return LayoutInfo(layoutAttr);
250}
251
252//===----------------------------------------------------------------------===//
253// LayoutInfoLattice
254//===----------------------------------------------------------------------===//
255
256/// Lattice holding the LayoutInfo for each value.
257struct LayoutInfoLattice : public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
260};
261
262/// Helper Functions to get default layouts. A `default layout` is a layout that
263/// is assigned to a value when the layout is not fixed by some anchor operation
264/// (like DPAS).
265
266/// Helper Function to get the default layout for uniform values like constants.
267/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
268/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
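///
/// For example (illustrative, assuming a subgroup size of 16): a 1D value
/// gets lane_layout = [16], lane_data = [1], i.e. 16 lanes each owning one
/// element; a 2D value gets lane_layout = [1, 16], lane_data = [1, 1], i.e.
/// the lanes are spread along the innermost dimension.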
269static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
270 unsigned rank,
271 const xegpu::uArch::uArch *uArch) {
272 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
273 if (rank == 1) {
274 return LayoutInfo(
275 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
276 }
277 return LayoutInfo(
278 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
279}
280
281static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
282 unsigned rank, int subgroupSize) {
283 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
284 if (rank == 1) {
285 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
286 }
287 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
288}
289
290/// Helper to get the default layout for a vector type.
291static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
292 const xegpu::uArch::uArch *uArch,
293 unsigned packingSize,
294 bool isScattered = false) {
295 // Expecting a 1D or 2D vector.
296 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
297 "Expected 1D or 2D vector.");
298 // Expecting int or float element type.
299 assert(vectorTy.getElementType().isIntOrFloat() &&
300 "Expected int or float element type.");
301 // If the rank is 1, then return default layout for 1D vector.
302 if (vectorTy.getRank() == 1)
303 return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
304 // Packing factor is determined by the element type bitwidth.
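 // For example (illustrative): with f16 elements (16 bits) and a packing
 // size of 32 bits, packingFactor = 32 / 16 = 2, so each lane owns two
 // contiguous elements along the packed dimension; for 32-bit elements the
 // packing factor stays 1.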
305 unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
306 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307 if (isScattered) {
308 return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
309 {uArch->getSubgroupSize(), 1},
310 {1, packingFactor}));
311 }
312 return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
313 {1, uArch->getSubgroupSize()},
314 {1, packingFactor}));
315}
316
317/// Helper to get the default layout for a tensor descriptor type.
318static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
319 const xegpu::uArch::uArch *uArch,
320 unsigned packingSize,
321 bool isScattered = false) {
322 // Expecting a 1D or 2D TensorDesc.
323 assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
324 "Expected 1D or 2D TensorDesc.");
325 // Expecting int or float element type.
326 assert(tdescTy.getElementType().isIntOrFloat() &&
327 "Expected int or float element type.");
328 // If the rank is 1, then return default layout for 1D vector.
329 if (tdescTy.getRank() == 1)
330 return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
331 // Packing factor is determined by the element type bitwidth.
332 unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
333 int subgroupSize = uArch->getSubgroupSize();
334 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
335 if (isScattered) {
336 return LayoutInfo(xegpu::LayoutAttr::get(
337 tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
338 }
339
340 return LayoutInfo(xegpu::LayoutAttr::get(
341 tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
342}
343
344/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
345/// is set according to the following criteria:
346/// * For A operand, the data must be packed in minimum
347/// `packedSizeInBitsForDefault`
348/// * For B operand, the data must be packed in minimum
349/// `packedSizeInBitsForDpasB`
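///
/// Worked example (illustrative, values not tied to a specific uArch): for a
/// 16-bit B operand with a 32-bit packed format, lane_data becomes
/// [32 / 16, 1] = [2, 1], i.e. each lane holds two consecutive rows packed
/// into one 32-bit word (VNNI layout), while the A operand keeps the default
/// lane_data of [1, packingFactor].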
350static LayoutInfo
351getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
352 const xegpu::uArch::uArch *uArch,
353 unsigned packingSize) {
354 Type elementTy = vectorTy.getElementType();
355 assert(elementTy.isIntOrFloat() &&
356 "Expected int or float type in DPAS operands");
358 // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
359 // must have the VNNI format.
360 if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
361 SmallVector<int32_t> data(
362 {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
363 1});
364 return LayoutInfo(
365 xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
366 }
367 // Otherwise, return the default layout for the vector type.
368 return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
369}
370
371//===----------------------------------------------------------------------===//
372// LayoutInfoPropagation
373//===----------------------------------------------------------------------===//
374
375/// Backward data flow analysis to propagate the lane_layout and lane_data of
376/// each value in the program. Currently, the layouts for the operands of
377/// DPAS, StoreNd, and StoreScatter are fixed (known before propagation). The
378/// purpose of this analysis is to propagate those known layouts to all their
379/// producers and (other) consumers.
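///
/// Illustrative flow (schematic, not from the original source): if
/// `%v = xegpu.load_nd %desc` feeds `xegpu.store_nd %v, %dst`, the store is an
/// anchor, so its fixed layout is propagated backwards to %v, and the LoadNdOp
/// visitor then propagates it further to the tensor descriptor %desc.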
380class LayoutInfoPropagation
381 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
382private:
383 LayoutKind layoutKind;
384 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
385 ArrayRef<const LayoutInfoLattice *> results);
386
387 void visitStoreNdOp(xegpu::StoreNdOp store,
388 ArrayRef<LayoutInfoLattice *> operands,
389 ArrayRef<const LayoutInfoLattice *> results);
390
391 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
392 ArrayRef<LayoutInfoLattice *> operands,
393 ArrayRef<const LayoutInfoLattice *> results);
394
395 void visitLoadNdOp(xegpu::LoadNdOp load,
396 ArrayRef<LayoutInfoLattice *> operands,
397 ArrayRef<const LayoutInfoLattice *> results);
398
399 void visitLoadGatherOp(xegpu::LoadGatherOp load,
400 ArrayRef<LayoutInfoLattice *> operands,
401 ArrayRef<const LayoutInfoLattice *> results);
402
403 void visitTransposeOp(vector::TransposeOp transpose,
404 ArrayRef<LayoutInfoLattice *> operands,
405 ArrayRef<const LayoutInfoLattice *> results);
406
407 void visitVectorBitcastOp(vector::BitCastOp bitcast,
408 ArrayRef<LayoutInfoLattice *> operands,
409 ArrayRef<const LayoutInfoLattice *> results);
410
411 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
412 ArrayRef<LayoutInfoLattice *> operands,
413 ArrayRef<const LayoutInfoLattice *> results);
414
415 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
416 ArrayRef<LayoutInfoLattice *> operands,
417 ArrayRef<const LayoutInfoLattice *> results);
418
419 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
420 ArrayRef<LayoutInfoLattice *> operands,
421 ArrayRef<const LayoutInfoLattice *> results);
422
423 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
424 ArrayRef<LayoutInfoLattice *> operands,
425 ArrayRef<const LayoutInfoLattice *> results);
426
427 void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
428 ArrayRef<LayoutInfoLattice *> operands,
429 ArrayRef<const LayoutInfoLattice *> results);
430 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
431 ArrayRef<LayoutInfoLattice *> operands,
432 ArrayRef<const LayoutInfoLattice *> results);
433
434 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
435
436public:
437 LayoutInfoPropagation(DataFlowSolver &solver,
438 SymbolTableCollection &symbolTable,
439 LayoutKind layoutKind)
440 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
441 layoutKind(layoutKind) {}
442 using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
443
444 LogicalResult
445 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
446 ArrayRef<const LayoutInfoLattice *> results) override;
447
448 void visitBranchOperand(OpOperand &operand) override {};
449
450 void visitCallOperand(OpOperand &operand) override {};
451
452 void
453 visitNonControlFlowArguments(RegionSuccessor &successor,
454 ArrayRef<BlockArgument> arguments) override {};
455
456 void visitExternalCall(CallOpInterface call,
457 ArrayRef<LayoutInfoLattice *> operands,
458 ArrayRef<const LayoutInfoLattice *> results) override {
459 };
460
461 void setToExitState(LayoutInfoLattice *lattice) override {
462 (void)lattice->meet(LayoutInfo());
463 }
464};
465} // namespace
466
467LogicalResult LayoutInfoPropagation::visitOperation(
468 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
469 ArrayRef<const LayoutInfoLattice *> results) {
470 TypeSwitch<Operation *>(op)
471 .Case<xegpu::DpasOp>(
472 [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
473 .Case<xegpu::StoreNdOp>(
474 [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
475 .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
476 visitStoreScatterOp(storeScatterOp, operands, results);
477 })
478 .Case<xegpu::LoadNdOp>(
479 [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
480 .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
481 visitLoadGatherOp(loadGatherOp, operands, results);
482 })
483 .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
484 visitCreateDescOp(createDescOp, operands, results);
485 })
486 .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
487 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
488 })
489 .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
490 visitPrefetchNdOp(prefetchNdOp, operands, results);
491 })
492 .Case<vector::TransposeOp>([&](auto transposeOp) {
493 visitTransposeOp(transposeOp, operands, results);
494 })
495 .Case<vector::BitCastOp>([&](auto bitcastOp) {
496 visitVectorBitcastOp(bitcastOp, operands, results);
497 })
498 .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
499 visitVectorMultiReductionOp(reductionOp, operands, results);
500 })
501 .Case<vector::BroadcastOp>([&](auto broadcastOp) {
502 visitVectorBroadCastOp(broadcastOp, operands, results);
503 })
504 .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
505 visitShapeCastOp(shapeCastOp, operands, results);
506 })
507 // All other ops.
508 .Default([&](Operation *op) {
509 for (const LayoutInfoLattice *resultInfo : results) {
510 if (!resultInfo->getValue().isAssigned())
511 continue;
512 for (auto [operandInfo, operand] :
513 llvm::zip(operands, op->getOpOperands())) {
514 // If the operand type is not a vector or tensor descriptor, skip
515 // it.
516 if (!isa<xegpu::TensorDescType, VectorType>(
517 operand.get().getType()))
518 continue;
519 // Propagate the result layout to the operand.
520 meet(operandInfo, *resultInfo);
521 }
522 }
523 });
524
525 return success();
526}
527
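/// Returns true if `anchorLayout` already carries the parameters relevant to
/// the layout kind being propagated (lane, inst_data, or subgroup), meaning
/// the anchor's layout can be reused as-is instead of computing a default.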
528bool LayoutInfoPropagation::hasParamsOfLayoutKind(
529 xegpu::DistributeLayoutAttr anchorLayout) {
530 if (anchorLayout == nullptr) {
531 return false;
532 }
533 if (layoutKind == LayoutKind::InstData) {
534 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
535 } else if (layoutKind == LayoutKind::Lane) {
536 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
537 anchorLayout.getEffectiveLaneDataAsInt().empty());
538 } else if (layoutKind == LayoutKind::Subgroup) {
539 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
540 anchorLayout.getEffectiveSgDataAsInt().empty());
541 }
542 return false;
543}
544
545void LayoutInfoPropagation::visitPrefetchNdOp(
546 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
547 ArrayRef<const LayoutInfoLattice *> results) {
548
549 LayoutInfo prefetchLayout;
550 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
551 if (hasParamsOfLayoutKind(anchorLayout)) {
552 prefetchLayout = LayoutInfo(anchorLayout);
553 } else {
554 // Here we assign the default layout to the tensor descriptor operand of
555 // prefetch.
556 auto tdescTy = prefetch.getTensorDescType();
557
558 auto uArch = getUArch(getChipStr(prefetch).value_or(""));
559 const auto *uArchInstruction =
560 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
561 uArch->getInstruction(
562 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
563
564 auto blockWHC =
565 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
566 if (!blockWHC)
567 prefetch.emitWarning("No known block params found for the element type.");
568 auto [bWidth, bHeight, bCount] = blockWHC.value();
569 SmallVector<int> instData;
570 int instWidth = xegpu::getLargestDivisor(
571 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
572 if (instWidth == -1)
573 prefetch.emitWarning(
574 "No suitable instruction multiple found for the given shape.");
575 if (tdescTy.getRank() == 1)
576 instData = {instWidth};
577 else {
578 int instHeight = xegpu::getLargestDivisor(
579 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
580 if (instHeight == -1)
581 prefetch.emitWarning(
582 "No suitable instruction multiple found for the given shape.");
583 instData = {instHeight, instWidth};
584 }
585
586 if (layoutKind == LayoutKind::InstData)
587 prefetchLayout =
588 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
589 else
590 prefetchLayout = getDefaultSIMTLayoutInfo(
591 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
592
593 prefetch.setLayoutAttr(
594 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
595 }
596 // Propagate the layout to the source tensor descriptor.
597 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
598}
599
600void LayoutInfoPropagation::visitVectorMultiReductionOp(
601 vector::MultiDimReductionOp reduction,
602 ArrayRef<LayoutInfoLattice *> operands,
603 ArrayRef<const LayoutInfoLattice *> results) {
604 // The layout of the result must be present.
605 LayoutInfo resultLayout = results[0]->getValue();
606 if (!resultLayout.isAssigned())
607 return;
608 // We only consider 2D -> 1D reductions at this point.
609 VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
610 if (!resultTy || resultTy.getRank() != 1) {
611 reduction.emitWarning("Expecting output type to be 1D vector.");
612 return;
613 }
614 auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
615 // Given that the result is 1D, the layout of the operand should be 2D with
616 // default layout.
617 LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
618 reduction->getContext(), 2, uArch->getSubgroupSize());
619 propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
620 // Accumulator should have the same layout as the result.
621 propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
622}
623
624void LayoutInfoPropagation::visitVectorBroadCastOp(
625 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
626 ArrayRef<const LayoutInfoLattice *> results) {
627 // The layout of the result must be present.
628 LayoutInfo resultLayout = results[0]->getValue();
629 if (!resultLayout.isAssigned())
630 return;
631 // Only consider vector to vector broadcasts for now.
632 VectorType resultTy = broadcast.getResultVectorType();
633 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
634 // skip layout propagation for non-vector source operand.
635 if (!sourceTy)
636 return;
637
638 // Handle the broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
639 if (sourceTy.getRank() != resultTy.getRank()) {
640 auto sourceDims = sourceTy.getShape();
641 auto resultDims = resultTy.getShape();
642 SmallVector<int64_t> bcastDims;
643 auto dimDiff = resultTy.getRank() - sourceTy.getRank();
644 // adding the missing leading dims
645 for (int i = 0; i < dimDiff; i++)
646 bcastDims.push_back(i);
647
648 // for the remaining dims in resultTy, if the sourceTy dim is 1 and the
649 // corresponding result dim is not, it is a broadcasted dim
650 for (size_t i = 0; i < sourceDims.size(); i++)
651 if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
652 bcastDims.push_back(i + dimDiff);
653
654 // create a slice layout for the source
655 xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
656 broadcast->getContext(),
657 cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
658 DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
659
660 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
661 return;
662 }
663 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
664}
665
666void LayoutInfoPropagation::visitShapeCastOp(
667 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
668 ArrayRef<const LayoutInfoLattice *> results) {
669 // The layout of the result must be present.
670 LayoutInfo resultLayout = results[0]->getValue();
671 if (!resultLayout.isAssigned())
672 return;
673 VectorType sourceTy = shapeCast.getSourceVectorType();
674 VectorType resultTy = shapeCast.getResultVectorType();
675 // Shape cast layout propagation only supports 1D -> 2D shape casts.
676 // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
677 // unit dimensions and non-unit dims match.
678 if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
679 shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
680 return;
681 }
682 int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
683 xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
684 shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
685 DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
686 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
687}
688
689/// Propagate the layout of the result tensor to the source tensor descriptor
690/// in UpdateNdOffsetOp.
691void LayoutInfoPropagation::visitUpdateNdOffsetOp(
692 xegpu::UpdateNdOffsetOp updateNdOffset,
693 ArrayRef<LayoutInfoLattice *> operands,
694 ArrayRef<const LayoutInfoLattice *> results) {
695 // The layout of the result must be present.
696 LayoutInfo resultLayout = results[0]->getValue();
697 if (!resultLayout.isAssigned())
698 return;
699 // Propagate the layout to the source operand.
700 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
701}
702
703/// Set the layouts for DPAS A, B, and C operands.
704void LayoutInfoPropagation::visitDpasOp(
705 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
706 ArrayRef<const LayoutInfoLattice *> results) {
707
708 LayoutInfo dpasALayout;
709 LayoutInfo dpasBLayout;
710 LayoutInfo dpasCDLayout;
711
712 xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
713 if (hasParamsOfLayoutKind(anchorLayoutCD)) {
714 xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
715 xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
716 assert(hasParamsOfLayoutKind(anchorLayoutA) &&
717 "Expected anchor layout for DPAS A operand.");
718 assert(hasParamsOfLayoutKind(anchorLayoutB) &&
719 "Expected anchor layout for DPAS B operand.");
720 dpasALayout = LayoutInfo(anchorLayoutA);
721 dpasBLayout = LayoutInfo(anchorLayoutB);
722 dpasCDLayout = LayoutInfo(anchorLayoutCD);
723
724 } else {
725
726 VectorType aTy = dpas.getLhsType();
727 VectorType bTy = dpas.getRhsType();
728
729 auto uArch = getUArch(getChipStr(dpas).value_or(""));
730 const int subgroupSize = uArch->getSubgroupSize();
731 const auto *uArchInstruction =
732 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
733 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
734
735 const unsigned dataALen = aTy.getShape().front();
736 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
737 const int maxALen =
738 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
739 if (maxALen == -1)
740 dpas.emitWarning(
741 "No suitable instruction multiple found for the given shape.");
742
743 const unsigned dataBLen = bTy.getShape().back();
744 auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
745
746 const int maxBLen =
747 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
748
749 if (maxBLen == -1)
750 dpas.emitWarning(
751 "No suitable instruction multiple found for the given shape.");
752 SmallVector<int> instDataA = {maxALen, subgroupSize};
753 SmallVector<int> instDataB = {subgroupSize, maxBLen};
754
755 if (layoutKind == LayoutKind::InstData) {
756 dpasALayout =
757 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
758 dpasBLayout =
759 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
760 } else {
761 dpasALayout = getSIMTLayoutInfoForDPASOperand(
762 aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
763 dpasBLayout = getSIMTLayoutInfoForDPASOperand(
764 bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
765 }
766
767 if (operands.size() > 2) {
768 VectorType cTy = dpas.getAccType();
769 if (layoutKind == LayoutKind::InstData) {
770 const unsigned dataCLen = bTy.getShape().back();
771 auto supportedCLen =
772 uArchInstruction->getSupportedN(bTy.getElementType());
773 const int maxCLen = xegpu::getLargestDivisor(
774 dataCLen, ArrayRef<unsigned>(supportedCLen));
775 if (maxCLen == -1)
776 dpas.emitWarning(
777 "No suitable instruction multiple found for the given shape.");
778 SmallVector<int> instDataC = {maxALen, maxCLen};
779 dpasCDLayout =
780 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
781 } else
782 dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
783 cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
784
785 dpas.setLayoutCdAttr(
786 dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
787 }
788 dpas.setLayoutAAttr(
789 dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
790 dpas.setLayoutBAttr(
791 dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
792 }
793
794 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
795 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
796 if (operands.size() > 2) {
797 propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
798 }
799}
800
801/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
802void LayoutInfoPropagation::visitStoreNdOp(
803 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
804 ArrayRef<const LayoutInfoLattice *> results) {
805
806 LayoutInfo storeLayout;
807 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
808 if (hasParamsOfLayoutKind(anchorLayout)) {
809 storeLayout = LayoutInfo(anchorLayout);
810 } else {
811 auto uArch = getUArch(getChipStr(store).value_or(""));
812 const auto *uArchInstruction =
813 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
814 uArch->getInstruction(
815 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
816 VectorType dataTy = store.getValueType();
817 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
818 store.getValueType().getElementType());
819 if (!blockWHC)
820 store.emitWarning("No known block params found for the element type.");
821 auto [bWidth, bHeight, bCount] = blockWHC.value();
822 SmallVector<int> instData;
823 int instWidth = xegpu::getLargestDivisor(
824 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
825 if (instWidth == -1)
826 store.emitWarning(
827 "No suitable instruction multiple found for the given shape.");
828 if (dataTy.getRank() == 1)
829 instData = {instWidth};
830 else {
831 int instHeight = xegpu::getLargestDivisor(
832 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
833 if (instHeight == -1)
834 store.emitWarning(
835 "No suitable instruction multiple found for the given shape.");
836 instData = {instHeight, instWidth};
837 }
838
839 if (layoutKind == LayoutKind::InstData)
840 storeLayout =
841 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
842 else
843 storeLayout =
844 getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
845 uArchInstruction->getPackedFormatBitSize());
846 store.setLayoutAttr(
847 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
848 }
849 // Propagate the layout to the value operand.
850 // Both operands should have the same layout
851 for (LayoutInfoLattice *operand : operands)
852 propagateIfChanged(operand, operand->meet(storeLayout));
853}
854
855/// Propagate the layout of the value to the tensor descriptor operand in
856/// LoadNdOp.
857void LayoutInfoPropagation::visitLoadNdOp(
858 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
859 ArrayRef<const LayoutInfoLattice *> results) {
860
861 LayoutInfo loadLayout;
862 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
863 if (hasParamsOfLayoutKind(anchorLayout)) {
864 loadLayout = LayoutInfo(anchorLayout);
865 } else {
866
867 LayoutInfo valueLayout = results[0]->getValue();
868 // Need the layout of the value to propagate to the tensor descriptor.
869 if (!valueLayout.isAssigned())
870 return;
871 loadLayout = valueLayout;
872 // LoadNdOp can have a transpose effect. However, at this stage of the
873 // analysis that effect is not expected and should be abstracted away. Emit
874 // a warning.
875 if (auto transpose = load.getTranspose()) {
876 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
877 "LayoutInfoPropagation stage.");
878 loadLayout = valueLayout.transpose(transpose.value());
879 }
880 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
881 }
882 // Propagate the new layout to the tensor descriptor operand.
883 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
884}
885
886/// For vector::TransposeOp, the layout of the result is transposed and
887/// propagated to the operand.
888void LayoutInfoPropagation::visitTransposeOp(
889 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
890 ArrayRef<const LayoutInfoLattice *> results) {
891 // Need the layout of transpose result to propagate to the operands.
892 LayoutInfo resultLayout = results[0]->getValue();
893 if (!resultLayout.isAssigned())
894 return;
895 LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
896 // Propagate the new layout to the vector operand.
897 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
898}
899
900/// For vector::BitCastOp, the lane_data of the source layout is changed based
901/// on the bit width of the source and result types.
902void LayoutInfoPropagation::visitVectorBitcastOp(
903 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
904 ArrayRef<const LayoutInfoLattice *> results) {
905 // Need the layout of bitcast result to propagate to the operands.
906 LayoutInfo resultLayout = results[0]->getValue();
907 if (!resultLayout.isAssigned())
908 return;
909 int inElemTyBitWidth =
910 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
911 int outElemTyBitWidth =
912 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
913 // If the element bit widths are the same, then the layout does not change.
914 if (inElemTyBitWidth == outElemTyBitWidth) {
915 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
916 return;
917 }
918 // Check if the result layout is valid, i.e. the result vector can be distributed.
919 auto resultLaneLayout = resultLayout.getLaneLayout();
920 auto resultLaneData = resultLayout.getLaneData();
921 if (failed(xegpu::getDistributedVectorType(
922 bitcast.getResultVectorType(),
923 xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
924 resultLaneData)))) {
925 bitcast.emitWarning(
926 "Result vector type can not be evenly distributed across lanes.");
927 return;
928 }
929 int64_t rank = bitcast.getSourceVectorType().getRank();
930 // A bitcast is `narrowing` if the input element type bit width is larger
931 // than the output element type bit width, e.g. f32 -> f16 is a narrowing bitcast.
932 bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
933 int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
934 : outElemTyBitWidth / inElemTyBitWidth;
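 // For example (illustrative): a narrowing i32 -> i16 bitcast has
 // bitCastRatio 2, so the source's innermost lane_data is the result's
 // divided by 2; a widening i16 -> i32 bitcast multiplies it by 2 instead.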
935 SmallVector<int> sourceLaneLayout =
936 resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
937 SmallVector<int> outData = resultLayout.getLaneData();
938
939 // TODO: Currently we assume that bitcasts do not require cross lane
940 // communication. So each lane must own the required number of elements to
941 // perform the bitcast locally without cross-lane communication.
942 int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
943 if (outInnerBitsPerLane < inElemTyBitWidth) {
944 bitcast.emitWarning(
945 "Narrowing bitcast with cross lane communication is not supported.");
946 return;
947 }
948 // Check if each lane owns a single element in all dimensions except the
949 // innermost dimension.
950 SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
951 if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
952 bitcast.emitWarning("Each lane must not own multiple elements in any "
953 "dimension other than "
954 "the innermost dimension.");
955 return;
956 }
957 // Decide lane data based on whether the bitcast is narrowing or widening.
958 int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
959 : outData[rank - 1] * bitCastRatio;
960 sourceLaneData.push_back(innerMostLaneData);
961
962 propagateIfChanged(
963 operands[0],
964 operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
965 bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
966}
967
968/// Propagate the layout of the result to the tensor descriptor, mask and offset
969/// operands in LoadGatherOp.
970void LayoutInfoPropagation::visitLoadGatherOp(
971 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
972 ArrayRef<const LayoutInfoLattice *> results) {
973
974 LayoutInfo loadLayout;
975 LayoutInfo maskLayout;
976 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
977 if (hasParamsOfLayoutKind(anchorLayout)) {
978 loadLayout = LayoutInfo(anchorLayout);
979 maskLayout = loadLayout;
980 } else {
981
982 // The layout is strictly determined by the payload type.
983 VectorType payloadTy = load.getValueType();
984 if (!payloadTy) {
985 load.emitWarning("Not propagating, non-vector payload supplied.");
986 return;
987 }
988 auto uArch = getUArch(getChipStr(load).value_or(""));
989 const int subgroupSize = uArch->getSubgroupSize();
990 SmallVector<int> instData{subgroupSize};
991 if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
992 instData.push_back(chunkSize);
993 else if (auto srcTdescTy =
994 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
995 if (srcTdescTy.getChunkSizeAsInt() > 1)
996 instData.push_back(srcTdescTy.getChunkSizeAsInt());
997 }
998
999 if (layoutKind == LayoutKind::InstData)
1000 loadLayout =
1001 LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
1002 else
1003 loadLayout = getDefaultSIMTLayoutInfo(
1004 payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
1005 /*scattered*/ true);
1006
1007 // Mask operand should have 1D default layout.
1008 maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
1009
1010 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
1011 }
1012 // Propagate the new layout to the tensor descriptor operand.
1013 if (isa<xegpu::TensorDescType>(load.getSourceType()))
1014 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
1015 // Propagate the new layout to the mask and optional offset operand.
1016 propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
1017 if (load.getOffsets())
1018 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
1019}
1020
1021/// Propagate the layout of the descriptor to the vector offset operand in
1022/// CreateDescOp.
1023void LayoutInfoPropagation::visitCreateDescOp(
1024 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
1025 ArrayRef<const LayoutInfoLattice *> results) {
1026 LayoutInfo descLayout = results[0]->getValue();
1027 // Need the layout of the descriptor to propagate to the operands.
1028 if (!descLayout.isAssigned())
1029 return;
1030 auto uArch = getUArch(getChipStr(createDesc).value_or(""));
1031 // For offset operand propagate 1D default layout.
1032 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
1033 uArch->getSubgroupSize());
1034 propagateIfChanged(operands[1], operands[1]->meet(layout));
1035}
1036
1037/// Set the layout for the value, tensor descriptor, offset and mask operands in
1038/// the StoreScatterOp.
1039void LayoutInfoPropagation::visitStoreScatterOp(
1040 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1041 ArrayRef<const LayoutInfoLattice *> results) {
1042
1043 LayoutInfo payloadLayout;
1044 LayoutInfo maskLayout;
1045 xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
1046 if (hasParamsOfLayoutKind(anchorLayout)) {
1047 payloadLayout = LayoutInfo(anchorLayout);
1048 maskLayout = payloadLayout;
1049 } else {
1050 // Currently, for 2D StoreScatterOp we expect that the height dimension of
1051 // the tensor descriptor is equal to the subgroup size. This is ensured by
1052 // the op verifier.
1053 VectorType payloadTy = storeScatter.getValueType();
1054 if (!payloadTy) {
1055 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
1056 return;
1057 }
1058
1059 auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
1060 const int subgroupSize = uArch->getSubgroupSize();
1061
1062 if (layoutKind == LayoutKind::InstData) {
1063 SmallVector<int> instData{subgroupSize};
1064 if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
1065 chunkSize > 1)
1066 instData.push_back(chunkSize);
1067 else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
1068 storeScatter.getDestType())) {
1069 if (dstTdescTy.getChunkSizeAsInt() > 1)
1070 instData.push_back(dstTdescTy.getChunkSizeAsInt());
1071 }
1072 payloadLayout = LayoutInfo(
1073 xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
1074 } else {
1075 auto payloadShape = payloadTy.getShape();
1076 if (payloadShape.size() > 1)
1077 assert(payloadShape[0] == subgroupSize &&
1078 "Expected the first dimension of 2D tensor descriptor to be "
1079 "equal to "
1080 "subgroup size.");
1081 payloadLayout = getDefaultSIMTLayoutInfo(
1082 payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
1083 /*scattered=*/true);
1084 }
1085
1086 maskLayout =
1087 getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
1088
1089 storeScatter.setLayoutAttr(
1090 dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
1091 }
1092 // Propagate the payload operand layout
1093 propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
1094 // Propagate the destination (if tdesc) operand layout
1095 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1096 propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
1097 // Propagate the new layout to the mask and optional offset operand.
1098 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
1099 if (storeScatter.getOffsets())
1100 propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
1101}
1102
1103namespace {
1104//===----------------------------------------------------------------------===//
1105// RunLayoutInfoPropagation
1106//===----------------------------------------------------------------------===//
1107
1108/// Driver class for running the LayoutInfoPropagation analysis.
1109class RunLayoutInfoPropagation {
1110public:
1111 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
1112
1113 RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
1114 SymbolTableCollection symbolTable;
1115 loadBaselineAnalyses(solver);
1116 solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
1117 (void)solver.initializeAndRun(op);
1118 }
1119
1120 LayoutInfo getLayoutInfo(Value val);
1121
1122 void printAnalysisResult(llvm::raw_ostream &os);
1123
1124private:
1125 DataFlowSolver solver;
1126 const Operation *target;
1127};
1128} // namespace
1129
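/// Look up the lattice state computed for `val`; returns an unassigned
/// LayoutInfo if the solver has no state for the value.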
1130LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1131 auto *state = solver.lookupState<LayoutInfoLattice>(val);
1132 if (!state)
1133 return {};
1134 return state->getValue();
1135}
1136
1137// Print the analysis result for debugging purposes.
1138void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1139 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1140 os << "function: " << funcOp.getName() << ":\n";
1141 // Function arguments
1142 for (BlockArgument arg : funcOp.getArguments()) {
1143 LayoutInfo layout = getLayoutInfo(arg);
1144 os << "argument: " << arg << "\n";
1145 os << "layout : ";
1146 layout.print(os);
1147 os << "\n";
1148 }
1149 // Function ops
1150 funcOp.walk([&](Operation *op) {
1151 // Skip ops that do not have results
1152 if (op->getResults().empty())
1153 return;
1154 os << "op : ";
1155 // For control-flow ops, print the op name only.
1156 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1157 os << op->getName();
1158 else
1159 op->print(os);
1160 os << "\n";
1161 // Print the layout for each result.
1162 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1163 LayoutInfo layout = getLayoutInfo(r);
1164 os << "layout for result #" << i << ": ";
1165 layout.print(os);
1166 os << "\n";
1167 }
1168 });
1169 };
1170
1171 SmallVector<FunctionOpInterface> funcOps;
1172 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1173 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1174 funcOps.push_back(funcOp);
1175
1176 // Collect all GpuFuncOps in the module.
1177 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1178 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1179 funcOps.push_back(gpuFuncOp);
1180 }
1181 }
1182 // Print the analysis result for each function.
1183 for (FunctionOpInterface funcOp : funcOps)
1184 printFunctionResult(funcOp);
1185}
1186
1187using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1188/// Update an operation with the layout of its results. If the result type is
1189/// a vector type, a temporary layout attribute is added to the operation. If
1190/// the result type is a tensor descriptor type, the type is updated with the
1191/// layout attribute. The users of the result are also updated with the layout
1192/// attribute.
1193static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1194 GetLayoutFnTy getLayoutOfValue) {
1195 // Region ops (like scf.for) are already handled by the
1196 // updateControlFlowOps.
1197 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1198 return success();
1199
1200 // Iterate over all the results.
1201 for (OpResult result : op->getResults()) {
1202 Type resultType = result.getType();
1203 // Layouts are needed only for vector and tensor descriptor types.
1204 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1205 continue;
1206 // If the result has no layout but has users, emit a warning and continue.
1207 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1208 if (!layout && result.getNumUses() > 0) {
1209 op->emitWarning("op has users but no layout assigned for its result");
1210 continue;
1211 }
1212 // If the result is a tensor descriptor type, update the tensor desc type
1213 // with layout.
1214 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1215 auto typeWithLayout = xegpu::TensorDescType::get(
1216 tensorDescTy.getContext(), tensorDescTy.getShape(),
1217 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1218 result.setType(typeWithLayout);
1219 continue;
1220 }
1221 // If the result is a vector type, add a temporary layout attribute to the
1222 // op.
1223 xegpu::setDistributeLayoutAttr(result, layout);
1224 }
1225 return success();
1226}
1227
1228/// Region ops like scf.for need special handling because they have blocks
1229/// inside. If the blocks have tensor descriptor type as block arguments,
1230/// their types must be updated. Also, a region op can have results that may
1231/// not have any users (e.g. A and B tiles). They are not assigned a layout by
1232/// the layout analysis because they have no users. However, inside the region
1233/// op the corresponding block arguments for these results do have layouts.
1234/// Therefore, in this case we still need to update the result types with the
1235/// layout attribute. This function updates the internal block
1236/// arguments and the result types of the region op with the assigned layouts.
1237/// clang-format off
1238/// Example: scf.for ... iter_args(...) -> (out types) {
1239/// ^bb0(block types):
1240/// ...
1241/// scf.yield ... : (yield types)
1242/// }
1243/// clang-format on
1244/// In this example, at scf.yield, control-flow can transfer to two successor
1245/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1246/// itself (yield the results). So we update both the block arguments of the
1247/// successor region (i.e. block types) and the result types of the scf.for op
1248/// (i.e. out types). Note that yield types are updated by respective
1249/// producers inside bb0.
1250static LogicalResult
1251updateControlFlowOps(mlir::OpBuilder &builder,
1252 mlir::RegionBranchTerminatorOpInterface terminator,
1253 GetLayoutFnTy getLayoutOfValue) {
1254 // Only process if the terminator is inside a region branch op.
1255 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1256 if (!branchOp)
1257 return success();
1258
1259 RegionBranchSuccessorMapping mapping;
1260 branchOp.getSuccessorOperandInputMapping(mapping,
1261 RegionBranchPoint(terminator));
1262 for (const auto &[successorOperand, successorInputs] : mapping) {
1263 for (Value successorInput : successorInputs) {
1264 Type inputType = successorInput.getType();
1265 // We only need to operate on tensor descriptor or vector types.
1266 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1267 continue;
1268 xegpu::DistributeLayoutAttr successorInputLayout =
1269 getLayoutOfValue(successorInput);
1270 xegpu::DistributeLayoutAttr successorOperandLayout =
1271 getLayoutOfValue(successorOperand->get());
1272
1273 // If either of the layouts is not assigned, we cannot proceed.
1274 if (!successorOperandLayout) {
1275 LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1276 "branch terminator: "
1277 << successorOperand->get() << "\n");
1278 return failure();
1279 }
1280 // We expect the layouts to match.
1281 if (successorInputLayout &&
1282 successorInputLayout != successorOperandLayout) {
1283 LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1284 "operand forwarded as the argument: "
1285 << successorInputLayout << " vs "
1286 << successorOperandLayout << "\n");
1287 return failure();
1288 }
1289 // Get tensor descriptor type with the layout.
1290 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1291 auto newTdescTy = xegpu::TensorDescType::get(
1292 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1293 tdescTy.getEncoding(), successorOperandLayout);
1294 successorInput.setType(newTdescTy);
1295 continue;
1296 }
1297 // If the type is a vector type and this region argument is an OpResult,
1298 // set the layout attribute on the OpResult.
1299 if (auto result = dyn_cast<OpResult>(successorInput))
1300 xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1301 }
1302 }
1303 return success();
1304}
1305
1306/// Update the function arguments and results with the layouts.
1307static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1308 mlir::FunctionOpInterface funcOp,
1309 GetLayoutFnTy getLayoutOfValue) {
1310 SmallVector<Type> newArgTypes;
1311 // Update the function arguments.
1312 for (BlockArgument arg : funcOp.getArguments()) {
1313 Type argType = arg.getType();
1314 newArgTypes.push_back(argType);
1315 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1316 continue;
1317 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1318 if (!layout) {
1319 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1320 << " but got none.\n");
1321 return failure();
1322 }
1323 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1324 auto newTdescTy = xegpu::TensorDescType::get(
1325 tensorDescTy.getContext(), tensorDescTy.getShape(),
1326 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1327 arg.setType(newTdescTy);
1328 newArgTypes.back() = newTdescTy;
1329 }
1330 }
1331 // Update the function type with the new argument types.
1332 // NOTE: We assume that function results are not expected to have layouts.
1333 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1334 funcOp.getResultTypes()));
1335 return success();
1336}
1337
1338namespace {
1339struct XeGPUPropagateLayoutPass final
1340 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1341 XeGPUPropagateLayoutPass() = default;
1342 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
1343 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1344 : XeGPUPropagateLayoutBase(options) {}
1345 void runOnOperation() override;
1346};
1347
1348} // namespace
1349
1350void XeGPUPropagateLayoutPass::runOnOperation() {
1351 LayoutKind layoutKind;
1352 if (this->layoutKind == "lane") {
1353 layoutKind = LayoutKind::Lane;
1354 } else if (this->layoutKind == "inst") {
1355 layoutKind = LayoutKind::InstData;
1356 } else if (this->layoutKind == "subgroup") {
1357 layoutKind = LayoutKind::Subgroup;
1358 } else {
1359 getOperation()->emitError("Unsupported layout kind option: " +
1360 this->layoutKind);
1361 signalPassFailure();
1362 return;
1363 }
1364 RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
1365 // Print the analysis result and exit. (for debugging purposes)
1366 if (printOnly) {
1367 auto &os = llvm::outs();
1368 analysis.printAnalysisResult(os);
1369 return;
1370 }
1371 // Helper to convert LayoutInfo to xegpu::DistributeLayoutAttr.
1372 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1373 LayoutInfo layout = analysis.getLayoutInfo(val);
1374 if (!layout.isAssigned())
1375 return {};
1376 xegpu::DistributeLayoutAttr layoutAttr =
1377 cast<xegpu::DistributeLayoutAttr>(layout.get());
1378 if (layout.isSliceLayout())
1379 return cast<xegpu::SliceAttr>(layoutAttr);
1380 return cast<xegpu::LayoutAttr>(layoutAttr);
1381 };
1382
1383 mlir::OpBuilder builder(&getContext());
1384 Operation *op = getOperation();
1385 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1386 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1387 LogicalResult r = success();
1388 TypeSwitch<Operation *>(&op)
1389 .Case<mlir::RegionBranchTerminatorOpInterface>(
1390 [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1391 r = updateControlFlowOps(builder, branchTermOp,
1392 getXeGPULayoutForValue);
1393 })
1394 .Case<mlir::FunctionOpInterface>(
1395 [&](mlir::FunctionOpInterface funcOp) {
1396 r = updateFunctionOpInterface(builder, funcOp,
1397 getXeGPULayoutForValue);
1398 })
1399 .Default([&](Operation *op) {
1400 r = updateOp(builder, op, getXeGPULayoutForValue);
1401 });
1402 if (failed(r)) {
1403 op.emitError("Failed to update operation with the layout.");
1404 return WalkResult::interrupt();
1405 }
1406 }
1407 return WalkResult::advance();
1408 });
1409 if (walkResult.wasInterrupted()) {
1410 signalPassFailure();
1411 return;
1412 }
1413}