XeGPUPropagateLayout.cpp
1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/Builders.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/Value.h"
25#include "mlir/IR/Visitors.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/LogicalResult.h"
37#include "llvm/Support/raw_ostream.h"
38
40
41namespace mlir {
42namespace xegpu {
43#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
45} // namespace xegpu
46} // namespace mlir
47
48#define DEBUG_TYPE "xegpu-propagate-layout"
49#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
50
51using namespace mlir;
52using namespace mlir::dataflow;
53
54namespace {
55
56enum class LayoutKind { Lane, InstData };
57
58//===----------------------------------------------------------------------===//
59// LayoutInfo
60//===----------------------------------------------------------------------===//
61
62/// Helper class for tracking the analysis state of an MLIR value. For layout
63/// propagation, the analysis state is simply the distribution layout of
64/// each value. The distribution layout information is encapsulated in the
65/// xegpu::DistributeLayoutAttr class, which can hold information about any type
66/// of distribution layout that the XeGPU dialect supports. The purpose of this
67/// analysis is to propagate a unique distribution layout to each value in the
68/// program, starting from a set of anchor operations (like DPAS, StoreNd, etc.).
69/// Note that the analysis reaches a fixed point once every value has been
70/// assigned some layout, and it does not try to modify any already assigned layouts.
71///
72/// Given this, LayoutInfo satisfies the following properties:
73/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
74/// assigned`.
75/// 2) Two LayoutInfo values are equal if they are both assigned or
76/// both not assigned. The concrete value of assigned state does not matter.
77/// 3) The meet operator works as follows:
78/// - If the current state is assigned, return the current state (a unique
79/// layout is already assigned; don't change it).
80/// - Otherwise, return the other state.
81
82struct LayoutInfo {
83private:
84 xegpu::DistributeLayoutAttr storage = nullptr;
85
86public:
87 LayoutInfo() = default;
88 LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
89
90 // Two lattice values are equal if they have `some` layout. The actual
91 // content of the layout does not matter.
92 bool operator==(const LayoutInfo &other) const {
93 return this->isAssigned() == other.isAssigned();
94 }
95
96 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
97
98 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
99
100 void print(raw_ostream &os) const;
101
102 bool isAssigned() const { return storage != nullptr; }
103
104 LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
105
106 SmallVector<int> getLaneLayout() const;
107
108 SmallVector<int> getLaneData() const;
109
110 SmallVector<int> getInstData() const;
111
112 bool isSliceLayout() const {
113 if (!isAssigned())
114 return false;
115 return isa<xegpu::SliceAttr>(storage);
116 }
117
118 int64_t getRank() const {
119 if (!isAssigned())
120 return -1;
121 return storage.getRank();
122 }
123
124 Attribute get() { return storage; }
125};
126
127SmallVector<int> LayoutInfo::getLaneLayout() const {
128 if (!isAssigned())
129 return {};
130 assert(storage.getEffectiveLaneLayoutAsInt().size() &&
131 "Expected lane layout to be assigned");
132 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
133 [](int64_t val) { return static_cast<int>(val); });
134}
135
136SmallVector<int> LayoutInfo::getLaneData() const {
137 if (!isAssigned())
138 return {};
139 assert(storage.getEffectiveLaneDataAsInt().size() &&
140 "Expected lane data to be assigned");
141 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
142 [](int64_t val) { return static_cast<int>(val); });
143}
144
145SmallVector<int> LayoutInfo::getInstData() const {
146 if (!isAssigned())
147 return {};
148 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
149 [](int64_t val) { return static_cast<int>(val); });
150}
151
152void LayoutInfo::print(raw_ostream &os) const {
153 if (isAssigned()) {
154 os << storage;
155 } else {
156 os << "Not assigned.";
157 }
158}
159
160LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
161 if (!lhs.isAssigned())
162 return rhs;
163 return lhs;
164}
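// For example, the meet above behaves as follows:
//   meet(assigned A, anything)        -> A   (the first assigned layout wins)
//   meet(not-assigned, assigned B)    -> B
//   meet(not-assigned, not-assigned)  -> not-assigned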
165
166/// Since this is a backward analysis, join method is not used.
167LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
168 llvm_unreachable("Join should not be triggered by layout propagation.");
169}
170
171/// Construct a new layout with the transposed inst_data or lane_layout,
172/// lane_data.
173LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
174 if (!isAssigned())
175 return {};
176 // Check if the permutation is valid.
177 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
178 bool hasDuplicates = seen.size() != permutation.size();
179 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
180 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
181 });
182
183 if (!withinRange || hasDuplicates) {
184 assert(false && "Invalid permutation for transpose.");
185 return {};
186 }
187
188 SmallVector<int32_t> laneLayout;
189 SmallVector<int32_t> laneData;
190 SmallVector<int32_t> instData;
191 for (int64_t idx : permutation) {
192 if (getLaneLayout().size()) {
193 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
194 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
195 }
196 if (getInstData().size())
197 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
198 }
199 xegpu::LayoutAttr layoutAttr;
200 if (getLaneLayout().size())
201 layoutAttr =
202 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
203 if (getInstData().size())
204 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
205 return LayoutInfo(layoutAttr);
206}
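// For example, transposing a layout with lane_layout = [1, 16] and
// lane_data = [1, 1] (16 being an illustrative subgroup size) by permutation
// {1, 0} yields lane_layout = [16, 1], lane_data = [1, 1]; an inst_data layout
// of [8, 16] becomes [16, 8].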
207
208//===----------------------------------------------------------------------===//
209// LayoutInfoLattice
210//===----------------------------------------------------------------------===//
211
212/// Lattice holding the LayoutInfo for each value.
213struct LayoutInfoLattice : public Lattice<LayoutInfo> {
215 using Lattice::Lattice;
216};
217
218/// Helper functions to get default layouts. A `default layout` is a layout that
219/// is assigned to a value when the layout is not fixed by some anchor operation
220/// (like DPAS).
221
222/// Helper function to get the default layout for uniform values like constants.
223/// For a 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
224/// For a 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
225static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
226 unsigned rank,
227 const xegpu::uArch::uArch *uArch) {
228 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
229 if (rank == 1) {
230 return LayoutInfo(
231 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
232 }
233 return LayoutInfo(
234 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
235}
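// For example, assuming a subgroup size of 16: rank 1 gives lane_layout = [16]
// and lane_data = [1]; rank 2 gives lane_layout = [1, 16] and lane_data = [1, 1].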
236
237static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
238 unsigned rank, int subgroupSize) {
239 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
240 if (rank == 1) {
241 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
242 }
243 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
244}
245
246/// Helper to get the default layout for a vector type.
247static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
248 const xegpu::uArch::uArch *uArch,
249 unsigned packingSize,
250 bool isScattered = false) {
251 // Expecting a 1D or 2D vector.
252 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
253 "Expected 1D or 2D vector.");
254 // Expecting int or float element type.
255 assert(vectorTy.getElementType().isIntOrFloat() &&
256 "Expected int or float element type.");
257 // If the rank is 1, then return default layout for 1D vector.
258 if (vectorTy.getRank() == 1)
259 return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
260 // Packing factor is determined by the element type bitwidth.
261 unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
262 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
263 if (isScattered) {
264 return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
265 {uArch->getSubgroupSize(), 1},
266 {1, packingFactor}));
267 }
268 return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
269 {1, uArch->getSubgroupSize()},
270 {1, packingFactor}));
271}
272
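// For example, assuming a subgroup size of 16 and packingSize = 32, a
// vector<16x16xf16> value gets lane_layout = [1, 16] and lane_data = [1, 2]
// (packingFactor = 32 / 16 = 2); with isScattered = true it gets
// lane_layout = [16, 1] and lane_data = [1, 2] instead.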
273/// Helper to get the default layout for a tensor descriptor type.
274static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
275 const xegpu::uArch::uArch *uArch,
276 unsigned packingSize,
277 bool isScattered = false) {
278 // Expecting a 1D or 2D TensorDesc.
279 assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
280 "Expected 1D or 2D TensorDesc.");
281 // Expecting int or float element type.
282 assert(tdescTy.getElementType().isIntOrFloat() &&
283 "Expected int or float element type.");
284 // If the rank is 1, then return default layout for 1D vector.
285 if (tdescTy.getRank() == 1)
286 return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
287 // Packing factor is determined by the element type bitwidth.
288 unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
289 int subgroupSize = uArch->getSubgroupSize();
290 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
291 if (isScattered) {
292 return LayoutInfo(xegpu::LayoutAttr::get(
293 tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
294 }
295
296 return LayoutInfo(xegpu::LayoutAttr::get(
297 tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
298}
299
300/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
301/// is set according to the following criteria:
302/// * For A operand, the data must be packed in minimum
303/// `packedSizeInBitsForDefault`
304/// * For B operand, the data must be packed in minimum
305/// `packedSizeInBitsForDpasB`
306static LayoutInfo
307getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
308 const xegpu::uArch::uArch *uArch,
309 unsigned packingSize) {
310 Type elementTy = vectorTy.getElementType();
311 assert(elementTy.isIntOrFloat() &&
312 "Expected int or float type in DPAS operands");
313 SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
314 // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
315 // must have the VNNI format.
316 if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
317 SmallVector<int32_t, 2> data(
318 {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
319 1});
320 return LayoutInfo(
321 xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
322 }
323 // Otherwise, return the default layout for the vector type.
324 return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
325}
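// For example, assuming a subgroup size of 16 and a 32-bit packed format for
// the B operand, a bf16 B operand (operandNum == 1) gets lane_layout = [1, 16]
// and lane_data = [2, 1], i.e. each lane owns two consecutive rows of one
// column (VNNI packing).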
326
327//===----------------------------------------------------------------------===//
328// LayoutInfoPropagation
329//===----------------------------------------------------------------------===//
330
331/// Backward data flow analysis to propagate the lane_layout and lane_data of
332/// each value in the program. Currently, the layouts for the operands of DPAS,
333/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
334/// of this analysis is to propagate those known layouts to all their producers
335/// and (other) consumers.
336class LayoutInfoPropagation
337 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
338private:
339 LayoutKind layoutKind;
340 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
341 ArrayRef<const LayoutInfoLattice *> results);
342
343 void visitStoreNdOp(xegpu::StoreNdOp store,
344 ArrayRef<LayoutInfoLattice *> operands,
345 ArrayRef<const LayoutInfoLattice *> results);
346
347 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
348 ArrayRef<LayoutInfoLattice *> operands,
349 ArrayRef<const LayoutInfoLattice *> results);
350
351 void visitLoadNdOp(xegpu::LoadNdOp load,
352 ArrayRef<LayoutInfoLattice *> operands,
353 ArrayRef<const LayoutInfoLattice *> results);
354
355 void visitLoadGatherOp(xegpu::LoadGatherOp load,
356 ArrayRef<LayoutInfoLattice *> operands,
357 ArrayRef<const LayoutInfoLattice *> results);
358
359 void visitTransposeOp(vector::TransposeOp transpose,
360 ArrayRef<LayoutInfoLattice *> operands,
361 ArrayRef<const LayoutInfoLattice *> results);
362
363 void visitVectorBitcastOp(vector::BitCastOp bitcast,
364 ArrayRef<LayoutInfoLattice *> operands,
365 ArrayRef<const LayoutInfoLattice *> results);
366
367 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
368 ArrayRef<LayoutInfoLattice *> operands,
369 ArrayRef<const LayoutInfoLattice *> results);
370
371 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
372 ArrayRef<LayoutInfoLattice *> operands,
373 ArrayRef<const LayoutInfoLattice *> results);
374
375 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
376 ArrayRef<LayoutInfoLattice *> operands,
377 ArrayRef<const LayoutInfoLattice *> results);
378
379 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
380 ArrayRef<LayoutInfoLattice *> operands,
381 ArrayRef<const LayoutInfoLattice *> results);
382
383 void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
384 ArrayRef<LayoutInfoLattice *> operands,
385 ArrayRef<const LayoutInfoLattice *> results);
386 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
387 ArrayRef<LayoutInfoLattice *> operands,
388 ArrayRef<const LayoutInfoLattice *> results);
389
390public:
391 LayoutInfoPropagation(DataFlowSolver &solver,
392 SymbolTableCollection &symbolTable,
393 LayoutKind layoutKind)
394 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
395 layoutKind(layoutKind) {}
397
398 LogicalResult
399 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
400 ArrayRef<const LayoutInfoLattice *> results) override;
401
402 void visitBranchOperand(OpOperand &operand) override {};
403
404 void visitCallOperand(OpOperand &operand) override {};
405
406 void visitExternalCall(CallOpInterface call,
407 ArrayRef<LayoutInfoLattice *> operands,
408 ArrayRef<const LayoutInfoLattice *> results) override {
409 };
410
411 void setToExitState(LayoutInfoLattice *lattice) override {
412 (void)lattice->meet(LayoutInfo());
413 }
414};
415} // namespace
416
417LogicalResult LayoutInfoPropagation::visitOperation(
418 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
419 ArrayRef<const LayoutInfoLattice *> results) {
420 TypeSwitch<Operation *>(op)
421 .Case<xegpu::DpasOp>(
422 [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
423 .Case<xegpu::StoreNdOp>(
424 [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
425 .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
426 visitStoreScatterOp(storeScatterOp, operands, results);
427 })
428 .Case<xegpu::LoadNdOp>(
429 [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
430 .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
431 visitLoadGatherOp(loadGatherOp, operands, results);
432 })
433 .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
434 visitCreateDescOp(createDescOp, operands, results);
435 })
436 .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
437 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
438 })
439 .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
440 visitPrefetchNdOp(prefetchNdOp, operands, results);
441 })
442 .Case<vector::TransposeOp>([&](auto transposeOp) {
443 visitTransposeOp(transposeOp, operands, results);
444 })
445 .Case<vector::BitCastOp>([&](auto bitcastOp) {
446 visitVectorBitcastOp(bitcastOp, operands, results);
447 })
448 .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
449 visitVectorMultiReductionOp(reductionOp, operands, results);
450 })
451 .Case<vector::BroadcastOp>([&](auto broadcastOp) {
452 visitVectorBroadCastOp(broadcastOp, operands, results);
453 })
454 .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
455 visitShapeCastOp(shapeCastOp, operands, results);
456 })
457 // All other ops.
458 .Default([&](Operation *op) {
459 for (const LayoutInfoLattice *resultInfo : results) {
460 if (!resultInfo->getValue().isAssigned())
461 continue;
462 for (auto [operandInfo, operand] :
463 llvm::zip(operands, op->getOpOperands())) {
464 // If the operand type is not a vector or tensor descriptor, skip
465 // it.
466 if (!isa<xegpu::TensorDescType, VectorType>(
467 operand.get().getType()))
468 continue;
469 // Propagate the result layout to the operand.
470 meet(operandInfo, *resultInfo);
471 }
472 }
473 });
474
475 return success();
476}
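// For example, for an elementwise op such as arith.addf on vectors, the default
// case above copies the already-assigned result layout to both vector operands;
// operands that are neither vectors nor tensor descriptors are left untouched.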
477
478void LayoutInfoPropagation::visitPrefetchNdOp(
479 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
480 ArrayRef<const LayoutInfoLattice *> results) {
481 // Here we assign the default layout to the tensor descriptor operand of
482 // prefetch.
483 auto tdescTy = prefetch.getTensorDescType();
484
485 auto uArch = getUArch(getChipStr(prefetch).value_or(""));
486 const auto *uArchInstruction =
487 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
488 uArch->getInstruction(
489 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
490
491 auto blockWHC =
492 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
493 if (!blockWHC)
494 prefetch.emitWarning("No known block params found for the element type.");
495 auto [bWidth, bHeight, bCount] = blockWHC.value();
496 SmallVector<int> instData;
497 int instWidth = xegpu::getLargestDivisor(
498 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
499 bCount);
500 if (instWidth == -1)
501 prefetch.emitWarning(
502 "No suitable instruction multiple found for the given shape.");
503 if (tdescTy.getRank() == 1)
504 instData = {instWidth};
505 else {
506 int instHeight = xegpu::getLargestDivisor(
507 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
508 if (instHeight == -1)
509 prefetch.emitWarning(
510 "No suitable instruction multiple found for the given shape.");
511 instData = {instHeight, instWidth};
512 }
513 LayoutInfo prefetchLayout;
514 if (layoutKind == LayoutKind::InstData)
515 prefetchLayout =
516 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
517 else
518 prefetchLayout = getDefaultSIMTLayoutInfo(
519 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
520
521 // Propagate the layout to the source tensor descriptor.
522 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
523}
524
525void LayoutInfoPropagation::visitVectorMultiReductionOp(
526 vector::MultiDimReductionOp reduction,
527 ArrayRef<LayoutInfoLattice *> operands,
528 ArrayRef<const LayoutInfoLattice *> results) {
529 // The layout of the result must be present.
530 LayoutInfo resultLayout = results[0]->getValue();
531 if (!resultLayout.isAssigned())
532 return;
533 // We only consider 2D -> 1D reductions at this point.
534 VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
535 if (!resultTy || resultTy.getRank() != 1) {
536 reduction.emitWarning("Expecting output type to be 1D vector.");
537 return;
538 }
539 auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
540 // Given that the result is 1D, the layout of the operand should be 2D with
541 // default layout.
542 LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
543 reduction->getContext(), 2, uArch->getSubgroupSize());
544 propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
545 // Accumulator should have the same layout as the result.
546 propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
547}
548
549void LayoutInfoPropagation::visitVectorBroadCastOp(
550 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
551 ArrayRef<const LayoutInfoLattice *> results) {
552 // The layout of the result must be present.
553 LayoutInfo resultLayout = results[0]->getValue();
554 if (!resultLayout.isAssigned())
555 return;
556 // Only consider vector to vector broadcasts for now.
557 VectorType resultTy = broadcast.getResultVectorType();
558 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
559 if (!sourceTy) {
560 broadcast.emitWarning("Expecting source type to be a vector type.");
561 return;
562 }
563
564 // Only consider nD -> nD broadcast.
565 if (sourceTy.getRank() != resultTy.getRank()) {
566 broadcast.emitWarning("Expecting source and result to have same rank.");
567 return;
568 }
569 SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
570 if (broadcastUnitDims.size() != 1) {
571 broadcast.emitWarning("Expecting source type to be nD vector only with "
572 "one broadcasted dimension.");
573 return;
574 }
575 // Propagate the result layout to the source operand.
576 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
577}
578
579void LayoutInfoPropagation::visitShapeCastOp(
580 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
581 ArrayRef<const LayoutInfoLattice *> results) {
582 // The layout of the result must be present.
583 LayoutInfo resultLayout = results[0]->getValue();
584 if (!resultLayout.isAssigned())
585 return;
586 VectorType sourceTy = shapeCast.getSourceVectorType();
587 VectorType resultTy = shapeCast.getResultVectorType();
588 // Shape cast layout propagation only supports 1D -> 2D shape casts.
589 // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
590 // unit dimensions and non-unit dims match.
591 if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
592 shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
593 return;
594 }
595 int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
596 xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
597 shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
598 DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
599 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
600}
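// For example, a shape cast vector<16xf32> -> vector<1x16xf32> whose result
// layout is lane_layout = [1, 16], lane_data = [1, 1] (assuming a 16-wide
// subgroup) propagates that layout sliced along dim 0 to the 1D source.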
601
602/// Propagate the layout of the result tensor to the source tensor descriptor
603/// in UpdateNdOffsetOp.
604void LayoutInfoPropagation::visitUpdateNdOffsetOp(
605 xegpu::UpdateNdOffsetOp updateNdOffset,
606 ArrayRef<LayoutInfoLattice *> operands,
607 ArrayRef<const LayoutInfoLattice *> results) {
608 // The layout of the result must be present.
609 LayoutInfo resultLayout = results[0]->getValue();
610 if (!resultLayout.isAssigned())
611 return;
612 // Propagate the layout to the source operand.
613 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
614}
615
616/// Set the layouts for DPAS A, B, and C operands.
617void LayoutInfoPropagation::visitDpasOp(
618 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
619 ArrayRef<const LayoutInfoLattice *> results) {
620 VectorType aTy = dpas.getLhsType();
621 VectorType bTy = dpas.getRhsType();
622
623 auto uArch = getUArch(getChipStr(dpas).value_or(""));
624 const int subgroupSize = uArch->getSubgroupSize();
625 const auto *uArchInstruction =
626 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
627 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
628
629 const unsigned dataALen = aTy.getShape().front();
630 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
631 const int maxALen =
632 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
633 if (maxALen == -1)
634 dpas.emitWarning(
635 "No suitable instruction multiple found for the given shape.");
636
637 const unsigned dataBLen = bTy.getShape().back();
638 auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
639 const int maxBLen =
640 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
641 if (maxBLen == -1)
642 dpas.emitWarning(
643 "No suitable instruction multiple found for the given shape.");
644 SmallVector<int> instDataA = {maxALen, subgroupSize};
645 SmallVector<int> instDataB = {subgroupSize, maxBLen};
646
647 LayoutInfo dpasALayout;
648 LayoutInfo dpasBLayout;
649 LayoutInfo dpasCLayout;
650
651 if (layoutKind == LayoutKind::InstData) {
652 dpasALayout =
653 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
654 dpasBLayout =
655 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
656 } else {
657 dpasALayout = getSIMTLayoutInfoForDPASOperand(
658 aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
659 dpasBLayout = getSIMTLayoutInfoForDPASOperand(
660 bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
661 }
662
663 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
664 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
665 if (operands.size() > 2) {
666 VectorType cTy = dpas.getAccType();
667 const unsigned dataCLen = bTy.getShape().back();
668 auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
669 const int maxCLen =
670 xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
671 if (maxCLen == -1)
672 dpas.emitWarning(
673 "No suitable instruction multiple found for the given shape.");
674 SmallVector<int> instDataC = {maxALen, maxCLen};
675
676 if (layoutKind == LayoutKind::InstData)
677 dpasCLayout =
678 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
679 else
680 dpasCLayout = getSIMTLayoutInfoForDPASOperand(
681 cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
682
683 propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
684 }
685}
686
687/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
688void LayoutInfoPropagation::visitStoreNdOp(
689 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
690 ArrayRef<const LayoutInfoLattice *> results) {
691
692 auto uArch = getUArch(getChipStr(store).value_or(""));
693 const auto *uArchInstruction =
694 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
695 uArch->getInstruction(
696 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
697 VectorType dataTy = store.getValueType();
698 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
699 store.getValueType().getElementType());
700 if (!blockWHC)
701 store.emitWarning("No known block params found for the element type.");
702 auto [bWidth, bHeight, bCount] = blockWHC.value();
703 SmallVector<int> instData;
704 int instWidth = xegpu::getLargestDivisor(
705 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
706 bCount);
707 if (instWidth == -1)
708 store.emitWarning(
709 "No suitable instruction multiple found for the given shape.");
710 if (dataTy.getRank() == 1)
711 instData = {instWidth};
712 else {
713 int instHeight = xegpu::getLargestDivisor(
714 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
715 if (instHeight == -1)
716 store.emitWarning(
717 "No suitable instruction multiple found for the given shape.");
718 instData = {instHeight, instWidth};
719 }
720
721 LayoutInfo storeLayout;
722 if (layoutKind == LayoutKind::InstData)
723 storeLayout =
724 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
725 else
726 storeLayout =
727 getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
728 uArchInstruction->getPackedFormatBitSize());
729 // Both operands should have the same layout
730 for (LayoutInfoLattice *operand : operands)
731 propagateIfChanged(operand, operand->meet(storeLayout));
732}
733
734/// Propagate the layout of the value to the tensor descriptor operand in
735/// LoadNdOp.
736void LayoutInfoPropagation::visitLoadNdOp(
737 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
738 ArrayRef<const LayoutInfoLattice *> results) {
739 LayoutInfo valueLayout = results[0]->getValue();
740 // Need the layout of the value to propagate to the tensor descriptor.
741 if (!valueLayout.isAssigned())
742 return;
743 LayoutInfo tensorDescLayout = valueLayout;
744 // LoadNdOp has a transpose effect. However, at this stage of the analysis
745 // this effect is not expected and should be abstracted away. Emit a
746 // warning.
747 if (auto transpose = load.getTranspose()) {
748 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
749 "LayoutInfoPropagation stage.");
750 tensorDescLayout = valueLayout.transpose(transpose.value());
751 }
752 // Propagate the new layout to the tensor descriptor operand.
753 propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
754}
755
756/// For vector::TransposeOp, the layout of the result is transposed and
757/// propagated to the operand.
758void LayoutInfoPropagation::visitTransposeOp(
759 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
760 ArrayRef<const LayoutInfoLattice *> results) {
761 // Need the layout of transpose result to propagate to the operands.
762 LayoutInfo resultLayout = results[0]->getValue();
763 if (!resultLayout.isAssigned())
764 return;
765 LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
766 // Propagate the new layout to the vector operand.
767 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
768}
769
770/// For vector::BitCastOp, the lane_data of the source layout is changed based
771/// on the bit width of the source and result types.
772void LayoutInfoPropagation::visitVectorBitcastOp(
773 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
774 ArrayRef<const LayoutInfoLattice *> results) {
775 // Need the layout of bitcast result to propagate to the operands.
776 LayoutInfo resultLayout = results[0]->getValue();
777 if (!resultLayout.isAssigned())
778 return;
779 int inElemTyBitWidth =
780 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
781 int outElemTyBitWidth =
782 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
783 // If the element bit widths are the same, then the layout does not change.
784 if (inElemTyBitWidth == outElemTyBitWidth) {
785 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
786 return;
787 }
788 // Check if the result layout is valid, i.e. the result vector can be distributed.
789 auto resultLaneLayout = resultLayout.getLaneLayout();
790 auto resultLaneData = resultLayout.getLaneData();
791 if (failed(xegpu::getDistributedVectorType(
792 bitcast.getResultVectorType(),
793 xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
794 resultLaneData)))) {
795 bitcast.emitWarning(
796 "Result vector type can not be evenly distributed across lanes.");
797 return;
798 }
799 int64_t rank = bitcast.getSourceVectorType().getRank();
800 // A bitcast is `narrowing` if the input element type bit width is larger than
801 // the output element type bit width, e.g. f32 -> f16 is a narrowing bitcast.
802 bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
803 int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
804 : outElemTyBitWidth / inElemTyBitWidth;
805 SmallVector<int> sourceLaneLayout =
806 resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
807 SmallVector<int> outData = resultLayout.getLaneData();
808
809 // TODO: Currently we assume that bitcasts do not require cross-lane
810 // communication, so each lane must own the required number of elements
811 // to perform the bitcast locally.
812 int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
813 if (outInnerBitsPerLane < inElemTyBitWidth) {
814 bitcast.emitWarning(
815 "Narrowing bitcast with cross lane communication is not supported.");
816 return;
817 }
818 // Check if each lane owns a single element in all dimensions except the
819 // innermost dimension.
820 SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
821 if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
822 bitcast.emitWarning("Each lane must not own multiple elements in any "
823 "dimension other than "
824 "the innermost dimension.");
825 return;
826 }
827 // Decide lane data based on whether the bitcast is narrowing or widening.
828 int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
829 : outData[rank - 1] * bitCastRatio;
830 sourceLaneData.push_back(innerMostLaneData);
831
832 propagateIfChanged(
833 operands[0],
834 operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
835 bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
836}
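// For example, a narrowing bitcast vector<16x16xi32> -> vector<16x32xi16> whose
// result layout is lane_layout = [1, 16], lane_data = [1, 2] (assuming a
// 16-wide subgroup) has bitCastRatio = 2, so the source operand receives
// lane_layout = [1, 16] and lane_data = [1, 1].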
837
838/// Propagate the layout of the result to the tensor descriptor, mask and offset
839/// operands in LoadGatherOp.
840void LayoutInfoPropagation::visitLoadGatherOp(
841 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
842 ArrayRef<const LayoutInfoLattice *> results) {
843 // The layout is strictly determined by the payload type.
844 auto payloadTy = dyn_cast<VectorType>(load.getValueType());
845 if (!payloadTy) {
846 load.emitWarning("Not propagating, non-vector payload supplied.");
847 return;
848 }
849 auto uArch = getUArch(getChipStr(load).value_or(""));
850 const int subgroupSize = uArch->getSubgroupSize();
851 SmallVector<int> instData{subgroupSize};
852 if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
853 instData.push_back(chunkSize);
854 else if (auto srcTdescTy =
855 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
856 if (srcTdescTy.getChunkSizeAsInt() > 1)
857 instData.push_back(srcTdescTy.getChunkSizeAsInt());
858 }
859 LayoutInfo layout;
860 if (layoutKind == LayoutKind::InstData)
861 layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
862 else
863 layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
864 uArch->getGeneralPackedFormatBitSize(),
865 /*scattered*/ true);
866
867 // Mask operand should have 1D default layout.
868 LayoutInfo maskLayout =
869 getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
870
871 // Propagate the new layout to the tensor descriptor operand.
872 if (isa<xegpu::TensorDescType>(load.getSourceType()))
873 propagateIfChanged(operands[0], operands[0]->meet(layout));
874 // Propagate the new layout to the mask and optional offset operand.
875 propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
876 if (load.getOffsets())
877 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
878}
879
880/// Propagate the layout of the descriptor to the vector offset operand in
881/// CreateDescOp.
882void LayoutInfoPropagation::visitCreateDescOp(
883 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
884 ArrayRef<const LayoutInfoLattice *> results) {
885 LayoutInfo descLayout = results[0]->getValue();
886 // Need the layout of the descriptor to propagate to the operands.
887 if (!descLayout.isAssigned())
888 return;
889 auto uArch = getUArch(getChipStr(createDesc).value_or(""));
890 // For offset operand propagate 1D default layout.
891 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
892 uArch->getSubgroupSize());
893 propagateIfChanged(operands[1], operands[1]->meet(layout));
894}
895
896/// Set the layout for the value, tensor descriptor, offset and mask operands in
897/// the StoreScatterOp.
898void LayoutInfoPropagation::visitStoreScatterOp(
899 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
900 ArrayRef<const LayoutInfoLattice *> results) {
901 // Currently, for 2D StoreScatterOp we expect that the height dimension of
902 // the tensor descriptor is equal to the subgroup size. This is ensured by
903 // the op verifier.
904 auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
905 if (!payloadTy) {
906 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
907 return;
908 }
909 LayoutInfo payloadLayout;
910 auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
911 const int subgroupSize = uArch->getSubgroupSize();
912
913 if (auto layout = storeScatter.getLayoutAttr()) {
914 payloadLayout = LayoutInfo(layout);
915 } else {
916 if (layoutKind == LayoutKind::InstData) {
917 SmallVector<int> instData{subgroupSize};
918 if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
919 chunkSize > 1)
920 instData.push_back(chunkSize);
921 else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
922 storeScatter.getDestType())) {
923 if (dstTdescTy.getChunkSizeAsInt() > 1)
924 instData.push_back(dstTdescTy.getChunkSizeAsInt());
925 }
926 payloadLayout = LayoutInfo(
927 xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
928 } else {
929 auto payloadShape = payloadTy.getShape();
930 if (payloadShape.size() > 1)
931 assert(payloadShape[0] == subgroupSize &&
932 "Expected the first dimension of 2D tensor descriptor to be "
933 "equal to "
934 "subgroup size.");
935 payloadLayout = getDefaultSIMTLayoutInfo(
936 payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
937 /*scattered=*/true);
938 }
939 }
940
941 LayoutInfo maskLayout =
942 getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
943 // Propagate the payload operand layout
944 propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
945 // Propagate the destination (if tdesc) operand layout
946 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
947 propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
948 // Propagate the new layout to the mask and optional offset operand.
949 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
950 if (storeScatter.getOffsets())
951 propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
952}
953
954namespace {
955//===----------------------------------------------------------------------===//
956// RunLayoutInfoPropagation
957//===----------------------------------------------------------------------===//
958
959/// Driver class for running the LayoutInfoPropagation analysis.
960class RunLayoutInfoPropagation {
961public:
962 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
963
964 RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
965 SymbolTableCollection symbolTable;
966 loadBaselineAnalyses(solver);
967 solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
968 (void)solver.initializeAndRun(op);
969 }
970
971 LayoutInfo getLayoutInfo(Value val);
972
973 void printAnalysisResult(llvm::raw_ostream &os);
974
975private:
976 DataFlowSolver solver;
977 const Operation *target;
978};
979} // namespace
980
981LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
982 auto *state = solver.lookupState<LayoutInfoLattice>(val);
983 if (!state)
984 return {};
985 return state->getValue();
986}
987
988// Print the analysis result for debugging purposes.
989void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
990 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
991 os << "function: " << funcOp.getName() << ":\n";
992 // Function arguments
993 for (BlockArgument arg : funcOp.getArguments()) {
994 LayoutInfo layout = getLayoutInfo(arg);
995 os << "argument: " << arg << "\n";
996 os << "layout : ";
997 layout.print(os);
998 os << "\n";
999 }
1000 // Function ops
1001 funcOp.walk([&](Operation *op) {
1002 // Skip ops that do not have results
1003 if (op->getResults().empty())
1004 return;
1005 os << "op : ";
1006 // For control-flow ops, print the op name only.
1007 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1008 os << op->getName();
1009 else
1010 op->print(os);
1011 os << "\n";
1012 // Print the layout for each result.
1013 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1014 LayoutInfo layout = getLayoutInfo(r);
1015 os << "layout for result #" << i << ": ";
1016 layout.print(os);
1017 os << "\n";
1018 }
1019 });
1020 };
1021
1022 SmallVector<FunctionOpInterface> funcOps;
1023 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1024 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1025 funcOps.push_back(funcOp);
1026
1027 // Collect all GpuFuncOps in the module.
1028 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1029 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1030 funcOps.push_back(gpuFuncOp);
1031 }
1032 }
1033 // Print the analysis result for each function.
1034 for (FunctionOpInterface funcOp : funcOps)
1035 printFunctionResult(funcOp);
1036}
1037
1038using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1039/// Update an operation with the layout of its results. If the result type is
1040/// a vector type, a temporary layout attribute is added to the operation. If
1041/// the result type is a tensor descriptor type, the type is updated with the
1042/// layout attribute. The users of the result are also updated with the layout
1043/// attribute.
1044static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1045 GetLayoutFnTy getLayoutOfValue) {
1046 // Region ops (like scf.for) are already handled by the
1047 // updateControlFlowOps.
1048 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1049 return success();
1050
1051 // Iterate over all the results.
1052 for (OpResult result : op->getResults()) {
1053 Type resultType = result.getType();
1054 // Layouts are needed only for vector and tensor descriptor types.
1055 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1056 continue;
1057 // If the result has no layout but has users, emit a warning and continue.
1058 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1059 if (!layout && result.getNumUses() > 0) {
1060 op->emitWarning("op has users but no layout assigned for its result");
1061 continue;
1062 }
1063 // If the result is a tensor descriptor type, update the tensor desc type
1064 // with layout.
1065 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1066 auto typeWithLayout = xegpu::TensorDescType::get(
1067 tensorDescTy.getContext(), tensorDescTy.getShape(),
1068 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1069 result.setType(typeWithLayout);
1070 continue;
1071 }
1072 // If the result is a vector type, add a temporary layout attribute to the
1073 // op.
1074 xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true);
1075 }
1076 return success();
1077}
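// For example (illustrative IR; the exact temporary attribute name depends on
// xegpu::setDistributeLayoutAttr), a vector-producing op such as
//   %v = xegpu.load_nd %tdesc : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
// would end up carrying a result layout like
//   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
// attached as a temporary attribute on the defining op.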
1078
1079/// Region ops like scf.for need special handling because they have blocks
1080/// inside. If the blocks have tensor descriptor type as block arguments,
1081/// their types must be updated. Also, a region op can have results that may not
1082/// have any users (e.g. A and B tiles). They are not assigned a layout by the
1083/// layout analysis because they have no users. However, inside the region op
1084/// the corresponding block arguments for these results do have layouts.
1085/// Therefore, in this case we still need to update the result types with the
1086/// layout attribute. This function updates the internal block
1087/// arguments and the result types of the region op with the assigned layouts.
1088/// clang-format off
1089/// Example: scf.for ... iter_args(...) -> (out types) {
1090/// ^bb0(block types):
1091/// ...
1092/// scf.yield ... : (yield types)
1093/// }
1094/// clang-format on
1095/// In this example, at scf.yield, control-flow can transfer to two successor
1096/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1097/// itself (yield the results). So we update both the block arguments of the
1098/// successor region (i.e. block types) and the result types of the scf.for op
1099/// (i.e. out types). Note that yield types are updated by respective
1100/// producers inside bb0.
1101static LogicalResult
1102updateControlFlowOps(mlir::OpBuilder &builder,
1103 mlir::RegionBranchTerminatorOpInterface terminator,
1104 GetLayoutFnTy getLayoutOfValue) {
1105 // Only process if the terminator is inside a region branch op.
1106 if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
1107 return success();
1108
1109 llvm::SmallVector<mlir::RegionSuccessor> successors;
1110 llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
1111 nullptr);
1112 terminator.getSuccessorRegions(operands, successors);
1113
1114 for (mlir::RegionSuccessor &successor : successors) {
1115 mlir::OperandRange successorOperands =
1116 terminator.getSuccessorOperands(successor);
1117 mlir::ValueRange successorInputs = successor.getSuccessorInputs();
1118 for (auto [successorOperand, successorInput] :
1119 llvm::zip(successorOperands, successorInputs)) {
1120 Type inputType = successorInput.getType();
1121 // We only need to operate on tensor descriptor or vector types.
1122 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1123 continue;
1124 xegpu::DistributeLayoutAttr successorInputLayout =
1125 getLayoutOfValue(successorInput);
1126 xegpu::DistributeLayoutAttr successorOperandLayout =
1127 getLayoutOfValue(successorOperand);
1128
1129 // If either of the layouts is not assigned, we cannot proceed.
1130 if (!successorOperandLayout) {
1131 LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1132 "branch terminator: "
1133 << successorOperand << "\n");
1134 return failure();
1135 }
1136 // We expect the layouts to match.
1137 if (successorInputLayout &&
1138 successorInputLayout != successorOperandLayout) {
1139 LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1140 "operand forwarded as the argument: "
1141 << successorInputLayout << " vs "
1142 << successorOperandLayout << "\n");
1143 return failure();
1144 }
1145 // Get tensor descriptor type with the layout.
1146 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1147 auto newTdescTy = xegpu::TensorDescType::get(
1148 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1149 tdescTy.getEncoding(), successorOperandLayout);
1150 successorInput.setType(newTdescTy);
1151 continue;
1152 }
1153 // If the type is a vector type and this region argument is an OpResult,
1154 // set the layout attribute on the OpResult.
1155 if (auto result = dyn_cast<OpResult>(successorInput))
1156 xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1157 }
1158 }
1159 return success();
1160}
1161
1162/// Update the function arguments and results with the layouts.
1163static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1164 mlir::FunctionOpInterface funcOp,
1165 GetLayoutFnTy getLayoutOfValue) {
1166 SmallVector<Type> newArgTypes;
1167 // Update the function arguments.
1168 for (BlockArgument arg : funcOp.getArguments()) {
1169 Type argType = arg.getType();
1170 newArgTypes.push_back(argType);
1171 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1172 continue;
1173 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1174 if (!layout) {
1175 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1176 << " but got none.\n");
1177 return failure();
1178 }
1179 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1180 auto newTdescTy = xegpu::TensorDescType::get(
1181 tensorDescTy.getContext(), tensorDescTy.getShape(),
1182 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1183 arg.setType(newTdescTy);
1184 newArgTypes.back() = newTdescTy;
1185 }
1186 }
1187 // Update the function type with the new argument types.
1188 // NOTE: We assume that function results are not expected to have layouts.
1189 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1190 funcOp.getResultTypes()));
1191 return success();
1192}
1193
1194namespace {
1195struct XeGPUPropagateLayoutPass final
1196 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1197 XeGPUPropagateLayoutPass() = default;
1198 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
1199 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1200 : XeGPUPropagateLayoutBase(options) {}
1201 void runOnOperation() override;
1202};
1203
1204} // namespace
1205
1206void XeGPUPropagateLayoutPass::runOnOperation() {
1207 LayoutKind layoutKind;
1208 if (this->layoutKind == "lane") {
1209 layoutKind = LayoutKind::Lane;
1210 } else if (this->layoutKind == "inst") {
1211 layoutKind = LayoutKind::InstData;
1212 } else {
1213 getOperation()->emitError("Unsupported layout kind option: " +
1214 this->layoutKind);
1215 signalPassFailure();
1216 return;
1217 }
1218 RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
1219 // Print the analysis result and exit. (for debugging purposes)
1220 if (printOnly) {
1221 auto &os = llvm::outs();
1222 analysis.printAnalysisResult(os);
1223 return;
1224 }
1225 // Helper to convert LayoutInfo to xegpu::DistributeLayoutAttr.
1226 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1227 LayoutInfo layout = analysis.getLayoutInfo(val);
1228 if (!layout.isAssigned())
1229 return {};
1230 xegpu::DistributeLayoutAttr layoutAttr =
1231 cast<xegpu::DistributeLayoutAttr>(layout.get());
1232 if (layout.isSliceLayout())
1233 return cast<xegpu::SliceAttr>(layoutAttr);
1234 return cast<xegpu::LayoutAttr>(layoutAttr);
1235 };
1236
1237 mlir::OpBuilder builder(&getContext());
1238 Operation *op = getOperation();
1239 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1240 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1241 LogicalResult r = success();
1242 TypeSwitch<Operation *>(&op)
1243 .Case<mlir::RegionBranchTerminatorOpInterface>(
1244 [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1245 r = updateControlFlowOps(builder, branchTermOp,
1246 getXeGPULayoutForValue);
1247 })
1248 .Case<mlir::FunctionOpInterface>(
1249 [&](mlir::FunctionOpInterface funcOp) {
1250 r = updateFunctionOpInterface(builder, funcOp,
1251 getXeGPULayoutForValue);
1252 })
1253 .Default([&](Operation *op) {
1254 r = updateOp(builder, op, getXeGPULayoutForValue);
1255 });
1256 if (failed(r)) {
1257 op.emitError("Failed to update operation with the layout.");
1258 return WalkResult::interrupt();
1259 }
1260 }
1261 return WalkResult::advance();
1262 });
1263 if (walkResult.wasInterrupted()) {
1264 signalPassFailure();
1265 return;
1266 }
1267}