1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/Builders.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/Value.h"
25#include "mlir/IR/Visitors.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/LogicalResult.h"
37#include "llvm/Support/raw_ostream.h"
38
40
41namespace mlir {
42namespace xegpu {
43#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
45} // namespace xegpu
46} // namespace mlir
47
48#define DEBUG_TYPE "xegpu-propagate-layout"
49#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
50
51using namespace mlir;
52using namespace mlir::dataflow;
53
54namespace {
55
56enum class LayoutKind { Lane, InstData };
57
58//===----------------------------------------------------------------------===//
59// LayoutInfo
60//===----------------------------------------------------------------------===//
61
62/// Helper class for tracking the analysis state of an MLIR value. For layout
63/// propagation, the analysis state is simply the distribution layout of
64/// each value. The distribution layout information is encapsulated in the
65/// xegpu::DistributeLayoutAttr class, which can hold information about any type
66/// of distribution layout that the XeGPU dialect supports. The purpose of this
67/// analysis is to propagate a unique distribution layout to each value in the
68/// program, starting from a set of anchor operations (like DPAS, StoreNd, etc.).
69/// Note that the analysis reaches a fixed point once every value has been
70/// assigned some layout; already assigned layouts are never modified.
71///
72/// Given this, LayoutInfo satisfies the following properties:
73/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
74/// assigned`.
75/// 2) Two LayoutInfo values are equal if they are both assigned or
76/// both not assigned. The concrete value of assigned state does not matter.
77/// 3) The meet operator works as follows:
78/// - If the current state is assigned, return the current state (a unique
79/// layout has already been assigned; don't change it).
80/// - Otherwise, return the other state.
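///
/// Illustrative example: meeting an unassigned LayoutInfo with one holding,
/// say, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> yields the
/// assigned one, while meeting two assigned LayoutInfos keeps the current
/// (lhs) layout unchanged.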
81
82struct LayoutInfo {
83private:
84 xegpu::DistributeLayoutAttr storage = nullptr;
85
86public:
87 LayoutInfo() = default;
88 LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
89
90 // Two lattice values are equal if they have `some` layout. The actual
91 // content of the layout does not matter.
92 bool operator==(const LayoutInfo &other) const {
93 return this->isAssigned() == other.isAssigned();
94 }
95
96 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
97
98 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
99
100 void print(raw_ostream &os) const;
101
102 bool isAssigned() const { return storage != nullptr; }
103
104 LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
105
106 SmallVector<int> getLaneLayout() const;
107
108 SmallVector<int> getLaneData() const;
109
110 SmallVector<int> getInstData() const;
111
112 bool isSliceLayout() const {
113 if (!isAssigned())
114 return false;
115 return isa<xegpu::SliceAttr>(storage);
116 }
117
118 int64_t getRank() const {
119 if (!isAssigned())
120 return -1;
121 return storage.getRank();
122 }
123
124 Attribute get() { return storage; }
125};
126
127SmallVector<int> LayoutInfo::getLaneLayout() const {
128 if (!isAssigned())
129 return {};
130 assert(storage.getEffectiveLaneLayoutAsInt().size() &&
131 "Expected lane layout to be assigned");
132 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
133 [](int64_t val) { return static_cast<int>(val); });
134}
135
136SmallVector<int> LayoutInfo::getLaneData() const {
137 if (!isAssigned())
138 return {};
139 assert(storage.getEffectiveLaneDataAsInt().size() &&
140 "Expected lane data to be assigned");
141 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
142 [](int64_t val) { return static_cast<int>(val); });
143}
144
145SmallVector<int> LayoutInfo::getInstData() const {
146 if (!isAssigned())
147 return {};
148 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
149 [](int64_t val) { return static_cast<int>(val); });
150}
151
152void LayoutInfo::print(raw_ostream &os) const {
153 if (isAssigned()) {
154 os << storage;
155 } else {
156 os << "Not assigned.";
157 }
158}
159
160LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
161 if (!lhs.isAssigned())
162 return rhs;
163 return lhs;
164}
165
166/// Since this is a backward analysis, the join method is not used.
167LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
168 llvm_unreachable("Join should not be triggered by layout propagation.");
169}
170
171/// Construct a new layout with the inst_data, or the lane_layout and
172/// lane_data, transposed according to the given permutation.
173LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
174 if (!isAssigned())
175 return {};
176 // Check if the permutation is valid.
177 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
178 bool hasDuplicates = seen.size() != permutation.size();
179 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
180 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
181 });
182
183 if (!withinRange || hasDuplicates) {
184 assert(false && "Invalid permutation for transpose.");
185 return {};
186 }
187
188 SmallVector<int32_t> laneLayout;
189 SmallVector<int32_t> laneData;
190 SmallVector<int32_t> instData;
191 for (int64_t idx : permutation) {
192 if (getLaneLayout().size()) {
193 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
194 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
195 }
196 if (getInstData().size())
197 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
198 }
199 xegpu::LayoutAttr layoutAttr;
200 if (getLaneLayout().size())
201 layoutAttr =
202 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
203 if (getInstData().size())
204 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
205 return LayoutInfo(layoutAttr);
206}
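// Illustrative example: transposing a layout with lane_layout = [1, 16] and
// lane_data = [1, 1] using permutation [1, 0] yields lane_layout = [16, 1],
// lane_data = [1, 1]; an inst_data layout such as [8, 16] becomes [16, 8].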
207
208//===----------------------------------------------------------------------===//
209// LayoutInfoLattice
210//===----------------------------------------------------------------------===//
211
212/// Lattice holding the LayoutInfo for each value.
213struct LayoutInfoLattice : public Lattice<LayoutInfo> {
214 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
215 using Lattice::Lattice;
216};
217
218/// Helper Functions to get default layouts. A `default layout` is a layout that
219/// is assigned to a value when the layout is not fixed by some anchor operation
220/// (like DPAS).
221
222/// Helper Function to get the default layout for uniform values like constants.
223/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
224/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
225static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
226 unsigned rank,
227 const xegpu::uArch::uArch *uArch) {
228 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
229 if (rank == 1) {
230 return LayoutInfo(
231 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
232 }
233 return LayoutInfo(
234 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
235}
236
237static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
238 unsigned rank, int subgroupSize) {
239 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
240 if (rank == 1) {
241 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
242 }
243 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
244}
245
246/// Helper to get the default layout for a vector type.
247static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
248 const xegpu::uArch::uArch *uArch,
249 unsigned packingSize,
250 bool isScattered = false) {
251 // Expecting a 1D or 2D vector.
252 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
253 "Expected 1D or 2D vector.");
254 // Expecting int or float element type.
255 assert(vectorTy.getElementType().isIntOrFloat() &&
256 "Expected int or float element type.");
257 // If the rank is 1, then return default layout for 1D vector.
258 if (vectorTy.getRank() == 1)
259 return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
260 // Packing factor is determined by the element type bitwidth.
261 unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
262 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
263 if (isScattered) {
264 return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
265 {uArch->getSubgroupSize(), 1},
266 {1, packingFactor}));
267 }
268 return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
269 {1, uArch->getSubgroupSize()},
270 {1, packingFactor}));
271}
272
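// Illustrative example, assuming a subgroup size of 16 and packingSize = 32:
// a 2D f16 vector gets packingFactor = 32 / 16 = 2, i.e. lane_layout = [1, 16]
// and lane_data = [1, 2]; with isScattered = true the lane_layout becomes
// [16, 1] instead.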
273/// Helper to get the default layout for a tensor descriptor type.
274static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
275 const xegpu::uArch::uArch *uArch,
276 unsigned packingSize,
277 bool isScattered = false) {
278 // Expecting a 1D or 2D TensorDesc.
279 assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
280 "Expected 1D or 2D TensorDesc.");
281 // Expecting int or float element type.
282 assert(tdescTy.getElementType().isIntOrFloat() &&
283 "Expected int or float element type.");
284 // If the rank is 1, then return default layout for 1D vector.
285 if (tdescTy.getRank() == 1)
286 return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
287 // Packing factor is determined by the element type bitwidth.
288 unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
289 int subgroupSize = uArch->getSubgroupSize();
290 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
291 if (isScattered) {
292 return LayoutInfo(xegpu::LayoutAttr::get(
293 tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
294 }
295
296 return LayoutInfo(xegpu::LayoutAttr::get(
297 tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
298}
299
300/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
301/// is set according to the following criteria:
302/// * For A operand, the data must be packed in minimum
303/// `packedSizeInBitsForDefault`
304/// * For B operand, the data must be packed in minimum
305/// `packedSizeInBitsForDpasB`
306static LayoutInfo
307getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
308 const xegpu::uArch::uArch *uArch,
309 unsigned packingSize) {
310 Type elementTy = vectorTy.getElementType();
311 assert(elementTy.isIntOrFloat() &&
312 "Expected int or float type in DPAS operands");
313 SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
314 // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
315 // must have the VNNI format.
316 if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
317 SmallVector<int32_t, 2> data(
318 {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
319 1});
320 return LayoutInfo(
321 xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
322 }
323 // Otherwise, return the default layout for the vector type.
324 return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
325}
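// Illustrative example, assuming a subgroup size of 16 and a 32-bit packed
// format for the B operand: a bf16 B operand (16-bit elements) gets the
// VNNI-style layout lane_layout = [1, 16], lane_data = [2, 1], while A and C
// operands fall back to the default layout above.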
326
327//===----------------------------------------------------------------------===//
328// LayoutInfoPropagation
329//===----------------------------------------------------------------------===//
330
331/// Backward data flow analysis to propagate the lane_layout and lane_data of
332/// each value in the program. Currently, the layouts for the operands of DPAS,
333/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
334/// of this analysis is to propagate those known layouts to all their producers
335/// and (other) consumers.
336class LayoutInfoPropagation
337 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
338private:
339 LayoutKind layoutKind;
340 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
341 ArrayRef<const LayoutInfoLattice *> results);
342
343 void visitStoreNdOp(xegpu::StoreNdOp store,
344 ArrayRef<LayoutInfoLattice *> operands,
345 ArrayRef<const LayoutInfoLattice *> results);
346
347 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
348 ArrayRef<LayoutInfoLattice *> operands,
349 ArrayRef<const LayoutInfoLattice *> results);
350
351 void visitLoadNdOp(xegpu::LoadNdOp load,
352 ArrayRef<LayoutInfoLattice *> operands,
353 ArrayRef<const LayoutInfoLattice *> results);
354
355 void visitLoadGatherOp(xegpu::LoadGatherOp load,
356 ArrayRef<LayoutInfoLattice *> operands,
357 ArrayRef<const LayoutInfoLattice *> results);
358
359 void visitTransposeOp(vector::TransposeOp transpose,
360 ArrayRef<LayoutInfoLattice *> operands,
361 ArrayRef<const LayoutInfoLattice *> results);
362
363 void visitVectorBitcastOp(vector::BitCastOp bitcast,
364 ArrayRef<LayoutInfoLattice *> operands,
365 ArrayRef<const LayoutInfoLattice *> results);
366
367 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
368 ArrayRef<LayoutInfoLattice *> operands,
369 ArrayRef<const LayoutInfoLattice *> results);
370
371 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
372 ArrayRef<LayoutInfoLattice *> operands,
373 ArrayRef<const LayoutInfoLattice *> results);
374
375 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
376 ArrayRef<LayoutInfoLattice *> operands,
377 ArrayRef<const LayoutInfoLattice *> results);
378
379 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
380 ArrayRef<LayoutInfoLattice *> operands,
381 ArrayRef<const LayoutInfoLattice *> results);
382
383 void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
384 ArrayRef<LayoutInfoLattice *> operands,
385 ArrayRef<const LayoutInfoLattice *> results);
386 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
387 ArrayRef<LayoutInfoLattice *> operands,
388 ArrayRef<const LayoutInfoLattice *> results);
389
390public:
391 LayoutInfoPropagation(DataFlowSolver &solver,
392 SymbolTableCollection &symbolTable,
393 LayoutKind layoutKind)
394 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
395 layoutKind(layoutKind) {}
397
398 LogicalResult
399 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
400 ArrayRef<const LayoutInfoLattice *> results) override;
401
402 void visitBranchOperand(OpOperand &operand) override {};
403
404 void visitCallOperand(OpOperand &operand) override {};
405
406 void visitExternalCall(CallOpInterface call,
407 ArrayRef<LayoutInfoLattice *> operands,
408 ArrayRef<const LayoutInfoLattice *> results) override {
409 };
410
411 void setToExitState(LayoutInfoLattice *lattice) override {
412 (void)lattice->meet(LayoutInfo());
413 }
414};
415} // namespace
416
417LogicalResult LayoutInfoPropagation::visitOperation(
418 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
419 ArrayRef<const LayoutInfoLattice *> results) {
420 TypeSwitch<Operation *>(op)
421 .Case<xegpu::DpasOp>(
422 [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
423 .Case<xegpu::StoreNdOp>(
424 [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
425 .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
426 visitStoreScatterOp(storeScatterOp, operands, results);
427 })
428 .Case<xegpu::LoadNdOp>(
429 [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
430 .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
431 visitLoadGatherOp(loadGatherOp, operands, results);
432 })
433 .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
434 visitCreateDescOp(createDescOp, operands, results);
435 })
436 .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
437 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
438 })
439 .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
440 visitPrefetchNdOp(prefetchNdOp, operands, results);
441 })
442 .Case<vector::TransposeOp>([&](auto transposeOp) {
443 visitTransposeOp(transposeOp, operands, results);
444 })
445 .Case<vector::BitCastOp>([&](auto bitcastOp) {
446 visitVectorBitcastOp(bitcastOp, operands, results);
447 })
448 .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
449 visitVectorMultiReductionOp(reductionOp, operands, results);
450 })
451 .Case<vector::BroadcastOp>([&](auto broadcastOp) {
452 visitVectorBroadCastOp(broadcastOp, operands, results);
453 })
454 .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
455 visitShapeCastOp(shapeCastOp, operands, results);
456 })
457 // All other ops.
458 .Default([&](Operation *op) {
459 for (const LayoutInfoLattice *resultInfo : results) {
460 if (!resultInfo->getValue().isAssigned())
461 continue;
462 for (auto [operandInfo, operand] :
463 llvm::zip(operands, op->getOpOperands())) {
464 // If the operand type is not a vector or tensor descriptor, skip
465 // it.
466 if (!isa<xegpu::TensorDescType, VectorType>(
467 operand.get().getType()))
468 continue;
469 // Propagate the result layout to the operand.
470 meet(operandInfo, *resultInfo);
471 }
472 }
473 });
474
475 return success();
476}
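// Illustrative example (the op is hypothetical): for an elementwise op such as
// arith.addf whose vector result already carries a layout, the default case
// above simply meets that layout into every vector (or tensor descriptor)
// operand, so known layouts flow backwards through ops that are not handled
// explicitly.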
477
478void LayoutInfoPropagation::visitPrefetchNdOp(
479 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
480 ArrayRef<const LayoutInfoLattice *> results) {
481 // Here we assign the default layout to the tensor descriptor operand of
482 // prefetch.
483 auto tdescTy = prefetch.getTensorDescType();
484
485 auto uArch = getUArch(getChipStr(prefetch).value_or(""));
486 const auto *uArchInstruction =
487 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
488 uArch->getInstruction(
489 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
490
491 auto blockWHC =
492 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
493 if (!blockWHC)
494 prefetch.emitWarning("No known block params found for the element type.");
495 auto [bWidth, bHeight, bCount] = blockWHC.value();
496 SmallVector<int> instData;
497 int instWidth = xegpu::getLargestDivisor(
498 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
499 if (instWidth == -1)
500 prefetch.emitWarning(
501 "No suitable instruction multiple found for the given shape.");
502 if (tdescTy.getRank() == 1)
503 instData = {instWidth};
504 else {
505 int instHeight = xegpu::getLargestDivisor(
506 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
507 if (instHeight == -1)
508 prefetch.emitWarning(
509 "No suitable instruction multiple found for the given shape.");
510 instData = {instHeight, instWidth};
511 }
512 LayoutInfo prefetchLayout;
513 if (layoutKind == LayoutKind::InstData)
514 prefetchLayout =
515 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
516 else
517 prefetchLayout = getDefaultSIMTLayoutInfo(
518 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
519
520 // Propagate the layout to the source tensor descriptor.
521 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
522}
523
524void LayoutInfoPropagation::visitVectorMultiReductionOp(
525 vector::MultiDimReductionOp reduction,
526 ArrayRef<LayoutInfoLattice *> operands,
527 ArrayRef<const LayoutInfoLattice *> results) {
528 // The layout of the result must be present.
529 LayoutInfo resultLayout = results[0]->getValue();
530 if (!resultLayout.isAssigned())
531 return;
532 // We only consider 2D -> 1D reductions at this point.
533 VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
534 if (!resultTy || resultTy.getRank() != 1) {
535 reduction.emitWarning("Expecting output type to be 1D vector.");
536 return;
537 }
538 auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
539 // Given that the result is 1D, the layout of the operand should be 2D with
540 // default layout.
541 LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
542 reduction->getContext(), 2, uArch->getSubgroupSize());
543 propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
544 // Accumulator should have the same layout as the result.
545 propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
546}
547
548void LayoutInfoPropagation::visitVectorBroadCastOp(
549 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
550 ArrayRef<const LayoutInfoLattice *> results) {
551 // The layout of the result must be present.
552 LayoutInfo resultLayout = results[0]->getValue();
553 if (!resultLayout.isAssigned())
554 return;
555 // Only consider vector to vector broadcasts for now.
556 VectorType resultTy = broadcast.getResultVectorType();
557 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
558 if (!sourceTy) {
559 broadcast.emitWarning("Expecting source type to be a vector type.");
560 return;
561 }
562
563 // Only consider nD -> nD broadcast.
564 if (sourceTy.getRank() != resultTy.getRank()) {
565 broadcast.emitWarning("Expecting source and result to have same rank.");
566 return;
567 }
568 SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
569 if (broadcastUnitDims.size() != 1) {
570 broadcast.emitWarning("Expecting source type to be nD vector only with "
571 "one broadcasted dimension.");
572 return;
573 }
574 // Propagate the result layout to the source operand.
575 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
576}
577
578void LayoutInfoPropagation::visitShapeCastOp(
579 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
580 ArrayRef<const LayoutInfoLattice *> results) {
581 // The layout of the result must be present.
582 LayoutInfo resultLayout = results[0]->getValue();
583 if (!resultLayout.isAssigned())
584 return;
585 VectorType sourceTy = shapeCast.getSourceVectorType();
586 VectorType resultTy = shapeCast.getResultVectorType();
587 // Shape cast layout propagation only supports 1D -> 2D shape casts.
588 // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
589 // unit dimensions and non-unit dims match.
590 if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
591 shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
592 return;
593 }
594 int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
595 xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
596 shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
597 DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
598 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
599}
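// Illustrative example: for a vector<16xf32> -> vector<1x16xf32> shape cast
// whose result layout is lane_layout = [1, 16], lane_data = [1, 1], slicedDim
// is 0, so the source operand receives the corresponding #xegpu.slice layout
// with dims = [0].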
600
601/// Propagate the layout of the result tensor descriptor to the source tensor
602/// descriptor in UpdateNdOffsetOp.
603void LayoutInfoPropagation::visitUpdateNdOffsetOp(
604 xegpu::UpdateNdOffsetOp updateNdOffset,
605 ArrayRef<LayoutInfoLattice *> operands,
606 ArrayRef<const LayoutInfoLattice *> results) {
607 // The layout of the result must be present.
608 LayoutInfo resultLayout = results[0]->getValue();
609 if (!resultLayout.isAssigned())
610 return;
611 // Propagate the layout to the source operand.
612 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
613}
614
615/// Set the layouts for DPAS A, B, and C operands.
616void LayoutInfoPropagation::visitDpasOp(
617 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
618 ArrayRef<const LayoutInfoLattice *> results) {
619 VectorType aTy = dpas.getLhsType();
620 VectorType bTy = dpas.getRhsType();
621
622 auto uArch = getUArch(getChipStr(dpas).value_or(""));
623 const int subgroupSize = uArch->getSubgroupSize();
624 const auto *uArchInstruction =
625 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
626 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
627
628 const unsigned dataALen = aTy.getShape().front();
629 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
630 const int maxALen =
631 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
632 if (maxALen == -1)
633 dpas.emitWarning(
634 "No suitable instruction multiple found for the given shape.");
635
636 const unsigned dataBLen = bTy.getShape().back();
637 auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
638 const int maxBLen =
639 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
640 if (maxBLen == -1)
641 dpas.emitWarning(
642 "No suitable instruction multiple found for the given shape.");
643 SmallVector<int> instDataA = {maxALen, subgroupSize};
644 SmallVector<int> instDataB = {subgroupSize, maxBLen};
645
646 LayoutInfo dpasALayout;
647 LayoutInfo dpasBLayout;
648 LayoutInfo dpasCLayout;
649
650 if (layoutKind == LayoutKind::InstData) {
651 dpasALayout =
652 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
653 dpasBLayout =
654 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
655 } else {
656 dpasALayout = getSIMTLayoutInfoForDPASOperand(
657 aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
658 dpasBLayout = getSIMTLayoutInfoForDPASOperand(
659 bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
660 }
661
662 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
663 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
664 if (operands.size() > 2) {
665 VectorType cTy = dpas.getAccType();
666 const unsigned dataCLen = bTy.getShape().back();
667 auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
668 const int maxCLen =
669 xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
670 if (maxCLen == -1)
671 dpas.emitWarning(
672 "No suitable instruction multiple found for the given shape.");
673 SmallVector<int> instDataC = {maxALen, maxCLen};
674
675 if (layoutKind == LayoutKind::InstData)
676 dpasCLayout =
677 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
678 else
679 dpasCLayout = getSIMTLayoutInfoForDPASOperand(
680 cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
681
682 propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
683 }
684}
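// Illustrative example, assuming a subgroup size of 16 and that the target
// reports supported M multiples {1, 2, 4, 8}: a 16x32 A tile gives maxALen = 8,
// so in InstData mode instDataA = {8, 16} and instDataB = {16, maxBLen}.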
685
686/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
687void LayoutInfoPropagation::visitStoreNdOp(
688 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
689 ArrayRef<const LayoutInfoLattice *> results) {
690
691 auto uArch = getUArch(getChipStr(store).value_or(""));
692 const auto *uArchInstruction =
693 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
694 uArch->getInstruction(
695 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
696 VectorType dataTy = store.getValueType();
697 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
698 store.getValueType().getElementType());
699 if (!blockWHC)
700 store.emitWarning("No known block params found for the element type.");
701 auto [bWidth, bHeight, bCount] = blockWHC.value();
702 SmallVector<int> instData;
703 int instWidth = xegpu::getLargestDivisor(
704 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
705 if (instWidth == -1)
706 store.emitWarning(
707 "No suitable instruction multiple found for the given shape.");
708 if (dataTy.getRank() == 1)
709 instData = {instWidth};
710 else {
711 int instHeight = xegpu::getLargestDivisor(
712 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
713 if (instHeight == -1)
714 store.emitWarning(
715 "No suitable instruction multiple found for the given shape.");
716 instData = {instHeight, instWidth};
717 }
718
719 LayoutInfo storeLayout;
720 if (layoutKind == LayoutKind::InstData)
721 storeLayout =
722 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
723 else
724 storeLayout =
725 getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
726 uArchInstruction->getPackedFormatBitSize());
727 // Both operands should have the same layout
728 for (LayoutInfoLattice *operand : operands)
729 propagateIfChanged(operand, operand->meet(storeLayout));
730}
731
732/// Propagate the layout of the value to the tensor descriptor operand in
733/// LoadNdOp.
734void LayoutInfoPropagation::visitLoadNdOp(
735 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
736 ArrayRef<const LayoutInfoLattice *> results) {
737 LayoutInfo valueLayout = results[0]->getValue();
738 // Need the layout of the value to propagate to the tensor descriptor.
739 if (!valueLayout.isAssigned())
740 return;
741 LayoutInfo tensorDescLayout = valueLayout;
742 // LoadNdOp has the transpose effect. However, at the stage of this analysis
743 // this effect is not expected and should be abstracted away. Emit a
744 // warning.
745 if (auto transpose = load.getTranspose()) {
746 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
747 "LayoutInfoPropagation stage.");
748 tensorDescLayout = valueLayout.transpose(transpose.value());
749 }
750 // Propagate the new layout to the tensor descriptor operand.
751 propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
752}
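// Illustrative example: if the loaded value has lane_layout = [1, 16],
// lane_data = [1, 1] and the load carries transpose = [1, 0], the tensor
// descriptor operand is assigned the transposed layout lane_layout = [16, 1],
// lane_data = [1, 1] (after the warning above is emitted).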
753
754/// For vector::TransposeOp, the layout of the result is transposed and
755/// propagated to the operand.
756void LayoutInfoPropagation::visitTransposeOp(
757 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
758 ArrayRef<const LayoutInfoLattice *> results) {
759 // Need the layout of transpose result to propagate to the operands.
760 LayoutInfo resultLayout = results[0]->getValue();
761 if (!resultLayout.isAssigned())
762 return;
763 LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
764 // Propagate the new layout to the vector operand.
765 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
766}
767
768/// For vector::BitCastOp, the lane_data of the source layout is changed based
769/// on the bit width of the source and result types.
770void LayoutInfoPropagation::visitVectorBitcastOp(
771 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
772 ArrayRef<const LayoutInfoLattice *> results) {
773 // Need the layout of bitcast result to propagate to the operands.
774 LayoutInfo resultLayout = results[0]->getValue();
775 if (!resultLayout.isAssigned())
776 return;
777 int inElemTyBitWidth =
778 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
779 int outElemTyBitWidth =
780 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
781 // If the element bit widths are the same, then the layout does not change.
782 if (inElemTyBitWidth == outElemTyBitWidth) {
783 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
784 return;
785 }
786 // Check if the result layout is valid. i.e. result vector can be distributed.
787 auto resultLaneLayout = resultLayout.getLaneLayout();
788 auto resultLaneData = resultLayout.getLaneData();
789 if (failed(xegpu::getDistributedVectorType(
790 bitcast.getResultVectorType(),
791 xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
792 resultLaneData)))) {
793 bitcast.emitWarning(
794 "Result vector type can not be evenly distributed across lanes.");
795 return;
796 }
797 int64_t rank = bitcast.getSourceVectorType().getRank();
798 // A bitcast is `narrowing` if the input element type bit width is larger
799 // than the output element type bit width, e.g., f32 -> f16 is a narrowing bitcast.
800 bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
801 int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
802 : outElemTyBitWidth / inElemTyBitWidth;
803 SmallVector<int> sourceLaneLayout =
804 resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
805 SmallVector<int> outData = resultLayout.getLaneData();
806
807 // TODO: Currently we assume that bitcasts do not require cross-lane
808 // communication, so each lane must own the required number of elements to
809 // perform the bitcast locally.
810 int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
811 if (outInnerBitsPerLane < inElemTyBitWidth) {
812 bitcast.emitWarning(
813 "Narrowing bitcast with cross lane communication is not supported.");
814 return;
815 }
816 // Check if each lane owns a single element in all dimensions except the
817 // innermost dimension.
818 SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
819 if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
820 bitcast.emitWarning("Each lane must not own multiple elements in any "
821 "dimension other than "
822 "the innermost dimension.");
823 return;
824 }
825 // Decide lane data based on whether the bitcast is narrowing or widening.
826 int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
827 : outData[rank - 1] * bitCastRatio;
828 sourceLaneData.push_back(innerMostLaneData);
829
830 propagateIfChanged(
831 operands[0],
832 operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
833 bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
834}
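// Illustrative example: for a widening bitcast vector<16x32xf16> ->
// vector<16x16xf32> with result lane_data = [1, 1], bitCastRatio is 2, so the
// source operand keeps the same lane_layout and gets lane_data = [1, 2].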
835
836/// Propagate the layout of the result to the tensor descriptor, mask and offset
837/// operands in LoadGatherOp.
838void LayoutInfoPropagation::visitLoadGatherOp(
839 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
840 ArrayRef<const LayoutInfoLattice *> results) {
841 // The layout is strictly determined by the payload type.
842 auto payloadTy = dyn_cast<VectorType>(load.getValueType());
843 if (!payloadTy) {
844 load.emitWarning("Not propagating, non-vector payload supplied.");
845 return;
846 }
847 auto uArch = getUArch(getChipStr(load).value_or(""));
848 const int subgroupSize = uArch->getSubgroupSize();
849 SmallVector<int> instData{subgroupSize};
850 if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
851 instData.push_back(chunkSize);
852 else if (auto srcTdescTy =
853 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
854 if (srcTdescTy.getChunkSizeAsInt() > 1)
855 instData.push_back(chunkSize);
856 }
857 LayoutInfo layout;
858 if (layoutKind == LayoutKind::InstData)
859 layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
860 else
861 layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
862 uArch->getGeneralPackedFormatBitSize(),
863 /*scattered*/ true);
864
865 // Mask operand should have 1D default layout.
866 LayoutInfo maskLayout =
867 getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
868
869 // Propagate the new layout to the tensor descriptor operand.
870 if (isa<xegpu::TensorDescType>(load.getSourceType()))
871 propagateIfChanged(operands[0], operands[0]->meet(layout));
872 // Propagate the new layout to the mask and optional offset operand.
873 propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
874 if (load.getOffsets())
875 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
876}
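// Illustrative example, assuming a subgroup size of 16: a gather with
// chunk_size = 4 gets instData = {16, 4} in InstData mode; in Lane mode the
// payload gets the scattered default layout (lane_layout = [16, 1]) while the
// mask and offsets get the 1D default layout with lane_layout = [16],
// lane_data = [1].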
877
878/// Propagate the layout of the descriptor to the vector offset operand in
879/// CreateDescOp.
880void LayoutInfoPropagation::visitCreateDescOp(
881 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
882 ArrayRef<const LayoutInfoLattice *> results) {
883 LayoutInfo descLayout = results[0]->getValue();
884 // Need the layout of the descriptor to propagate to the operands.
885 if (!descLayout.isAssigned())
886 return;
887 auto uArch = getUArch(getChipStr(createDesc).value_or(""));
888 // For offset operand propagate 1D default layout.
889 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
890 uArch->getSubgroupSize());
891 propagateIfChanged(operands[1], operands[1]->meet(layout));
892}
893
894/// Set the layout for the value, tensor descriptor, offset and mask operands in
895/// the StoreScatterOp.
896void LayoutInfoPropagation::visitStoreScatterOp(
897 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
898 ArrayRef<const LayoutInfoLattice *> results) {
899 // Currently, for 2D StoreScatterOp we expect that the height dimension of
900 // the tensor descriptor is equal to the subgroup size. This is ensured by
901 // the op verifier.
902 auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
903 if (!payloadTy) {
904 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
905 return;
906 }
907 LayoutInfo payloadLayout;
908 auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
909 const int subgroupSize = uArch->getSubgroupSize();
910
911 if (auto layout = storeScatter.getLayoutAttr()) {
912 payloadLayout = LayoutInfo(layout);
913 } else {
914 if (layoutKind == LayoutKind::InstData) {
915 SmallVector<int> instData{subgroupSize};
916 if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
917 chunkSize > 1)
918 instData.push_back(chunkSize);
919 else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
920 storeScatter.getDestType())) {
921 if (dstTdescTy.getChunkSizeAsInt() > 1)
922 instData.push_back(chunkSize);
923 }
924 payloadLayout = LayoutInfo(
925 xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
926 } else {
927 auto payloadShape = payloadTy.getShape();
928 if (payloadShape.size() > 1)
929 assert(payloadShape[0] == subgroupSize &&
930 "Expected the first dimension of 2D tensor descriptor to be "
931 "equal to "
932 "subgroup size.");
933 payloadLayout = getDefaultSIMTLayoutInfo(
934 payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
935 /*scattered=*/true);
936 }
937 }
938
939 LayoutInfo maskLayout =
940 getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
941 // Propagate the payload operand layout
942 propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
943 // Propagate the destination (if tdesc) operand layout
944 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
945 propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
946 // Propagate the new layout to the mask and optional offset operand.
947 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
948 if (storeScatter.getOffsets())
949 propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
950}
951
952namespace {
953//===----------------------------------------------------------------------===//
954// RunLayoutInfoPropagation
955//===----------------------------------------------------------------------===//
956
957/// Driver class for running the LayoutInfoPropagation analysis.
958class RunLayoutInfoPropagation {
959public:
960 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
961
962 RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
963 SymbolTableCollection symbolTable;
964 loadBaselineAnalyses(solver);
965 solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
966 (void)solver.initializeAndRun(op);
967 }
968
969 LayoutInfo getLayoutInfo(Value val);
970
971 void printAnalysisResult(llvm::raw_ostream &os);
972
973private:
974 DataFlowSolver solver;
975 const Operation *target;
976};
977} // namespace
978
979LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
980 auto *state = solver.lookupState<LayoutInfoLattice>(val);
981 if (!state)
982 return {};
983 return state->getValue();
984}
985
986// Print the analysis result for debugging purposes.
987void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
988 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
989 os << "function: " << funcOp.getName() << ":\n";
990 // Function arguments
991 for (BlockArgument arg : funcOp.getArguments()) {
992 LayoutInfo layout = getLayoutInfo(arg);
993 os << "argument: " << arg << "\n";
994 os << "layout : ";
995 layout.print(os);
996 os << "\n";
997 }
998 // Function ops
999 funcOp.walk([&](Operation *op) {
1000 // Skip ops that do not have results
1001 if (op->getResults().empty())
1002 return;
1003 os << "op : ";
1004 // For control-flow ops, print the op name only.
1005 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1006 os << op->getName();
1007 else
1008 op->print(os);
1009 os << "\n";
1010 // Print the layout for each result.
1011 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1012 LayoutInfo layout = getLayoutInfo(r);
1013 os << "layout for result #" << i << ": ";
1014 layout.print(os);
1015 os << "\n";
1016 }
1017 });
1018 };
1019
1020 SmallVector<FunctionOpInterface> funcOps;
1021 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1022 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1023 funcOps.push_back(funcOp);
1024
1025 // Collect all GpuFuncOps in the module.
1026 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1027 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1028 funcOps.push_back(gpuFuncOp);
1029 }
1030 }
1031 // Print the analysis result for each function.
1032 for (FunctionOpInterface funcOp : funcOps)
1033 printFunctionResult(funcOp);
1034}
1035
1036using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1037/// Update an operation with the layout of its results. If the result type is
1038/// a vector type, a temporary layout attribute is added to the operation. If
1039/// the result type is a tensor descriptor type, the type is updated with the
1040/// layout attribute. The users of the result are also updated with the layout
1041/// attribute.
1042static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1043 GetLayoutFnTy getLayoutOfValue) {
1044 // Region ops (like scf.for) are already handled by the
1045 // updateControlFlowOps.
1046 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1047 return success();
1048
1049 // Iterate over all the results.
1050 for (OpResult result : op->getResults()) {
1051 Type resultType = result.getType();
1052 // Layouts are needed only for vector and tensor descriptor types.
1053 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1054 continue;
1055 // If the result has no layout but has users, emit a warning and continue.
1056 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1057 if (!layout && result.getNumUses() > 0) {
1058 op->emitWarning("op has users but no layout assigned for its result");
1059 continue;
1060 }
1061 // If the result is a tensor descriptor type, update the tensor desc type
1062 // with layout.
1063 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1064 auto typeWithLayout = xegpu::TensorDescType::get(
1065 tensorDescTy.getContext(), tensorDescTy.getShape(),
1066 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1067 result.setType(typeWithLayout);
1068 continue;
1069 }
1070 // If the result is a vector type, add a temporary layout attribute to the
1071 // op.
1072 xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true);
1073 }
1074 return success();
1075}
1076
1077/// Region ops like scf.for need special handling because they have blocks
1078/// inside. If the blocks have tensor descriptor type as block arguments,
1079/// their types must be updated. Also, a region op can have results that may not
1080/// have any users (e.g. A and B tiles). They are not assigned a layout by
1081/// layout analysis because they have no users. However inside the region op
1082/// corresponding block arguments for these results do have layouts.
1083/// Therefore, in this case we still need to update the result types with the
1084/// layout attribute. This function updates the internal block
1085/// arguments and the result types of the region op with the assigned layouts.
1086/// clang-format off
1087/// Example: scf.for ... iter_args(...) -> (out types) {
1088/// ^bb0(block types):
1089/// ...
1090/// scf.yield ... : (yield types)
1091/// }
1092/// clang-format on
1093/// In this example, at scf.yield, control-flow can transfer to two successor
1094/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1095/// itself (yield the results). So we update both the block arguments of the
1096/// successor region (i.e. block types) and the result types of the scf.for op
1097/// (i.e. out types). Note that yield types are updated by respective
1098/// producers inside bb0.
1099static LogicalResult
1100updateControlFlowOps(mlir::OpBuilder &builder,
1101 mlir::RegionBranchTerminatorOpInterface terminator,
1102 GetLayoutFnTy getLayoutOfValue) {
1103 // Only process if the terminator is inside a region branch op.
1104 if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
1105 return success();
1106
1107 llvm::SmallVector<mlir::RegionSuccessor> successors;
1108 llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
1109 nullptr);
1110 terminator.getSuccessorRegions(operands, successors);
1111
1112 for (mlir::RegionSuccessor &successor : successors) {
1113 mlir::OperandRange successorOperands =
1114 terminator.getSuccessorOperands(successor);
1115 mlir::ValueRange successorInputs = successor.getSuccessorInputs();
1116 for (auto [successorOperand, successorInput] :
1117 llvm::zip(successorOperands, successorInputs)) {
1118 Type inputType = successorInput.getType();
1119 // We only need to operate on tensor descriptor or vector types.
1120 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1121 continue;
1122 xegpu::DistributeLayoutAttr successorInputLayout =
1123 getLayoutOfValue(successorInput);
1124 xegpu::DistributeLayoutAttr successorOperandLayout =
1125 getLayoutOfValue(successorOperand);
1126
1127 // If either of the layouts is not assigned, we cannot proceed.
1128 if (!successorOperandLayout) {
1129 LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1130 "branch terminator: "
1131 << successorOperand << "\n");
1132 return failure();
1133 }
1134 // We expect the layouts to match.
1135 if (successorInputLayout &&
1136 successorInputLayout != successorOperandLayout) {
1137 LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1138 "operand forwarded as the argument: "
1139 << successorInputLayout << " vs "
1140 << successorOperandLayout << "\n");
1141 return failure();
1142 }
1143 // Get tensor descriptor type with the layout.
1144 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1145 auto newTdescTy = xegpu::TensorDescType::get(
1146 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1147 tdescTy.getEncoding(), successorOperandLayout);
1148 successorInput.setType(newTdescTy);
1149 continue;
1150 }
1151 // If the type is a vector type and this region argument is an OpResult,
1152 // set the layout attribute on the OpResult.
1153 if (auto result = dyn_cast<OpResult>(successorInput))
1154 xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1155 }
1156 }
1157 return success();
1158}
1159
1160/// Update the function arguments and results with the layouts.
1161static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1162 mlir::FunctionOpInterface funcOp,
1163 GetLayoutFnTy getLayoutOfValue) {
1164 SmallVector<Type> newArgTypes;
1165 // Update the function arguments.
1166 for (BlockArgument arg : funcOp.getArguments()) {
1167 Type argType = arg.getType();
1168 newArgTypes.push_back(argType);
1169 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1170 continue;
1171 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1172 if (!layout) {
1173 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1174 << " but got none.\n");
1175 return failure();
1176 }
1177 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1178 auto newTdescTy = xegpu::TensorDescType::get(
1179 tensorDescTy.getContext(), tensorDescTy.getShape(),
1180 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1181 arg.setType(newTdescTy);
1182 newArgTypes.back() = newTdescTy;
1183 }
1184 }
1185 // Update the function type with the new argument types.
1186 // NOTE: We assume that function results are not expected to have layouts.
1187 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1188 funcOp.getResultTypes()));
1189 return success();
1190}
1191
1192namespace {
1193struct XeGPUPropagateLayoutPass final
1194 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1195 XeGPUPropagateLayoutPass() = default;
1196 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
1197 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1198 : XeGPUPropagateLayoutBase(options) {}
1199 void runOnOperation() override;
1200};
1201
1202} // namespace
1203
1204void XeGPUPropagateLayoutPass::runOnOperation() {
1205 LayoutKind layoutKind;
1206 if (this->layoutKind == "lane") {
1207 layoutKind = LayoutKind::Lane;
1208 } else if (this->layoutKind == "inst") {
1209 layoutKind = LayoutKind::InstData;
1210 } else {
1211 getOperation()->emitError("Unsupported layout kind option: " +
1212 this->layoutKind);
1213 signalPassFailure();
1214 return;
1215 }
1216 RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
1217 // Print the analysis result and exit. (for debugging purposes)
1218 if (printOnly) {
1219 auto &os = llvm::outs();
1220 analysis.printAnalysisResult(os);
1221 return;
1222 }
1223 // Helper to convert LayoutInfo to xegpu::LayoutAttr.
1224 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1225 LayoutInfo layout = analysis.getLayoutInfo(val);
1226 if (!layout.isAssigned())
1227 return {};
1228 xegpu::DistributeLayoutAttr layoutAttr =
1229 cast<xegpu::DistributeLayoutAttr>(layout.get());
1230 if (layout.isSliceLayout())
1231 return cast<xegpu::SliceAttr>(layoutAttr);
1232 return cast<xegpu::LayoutAttr>(layoutAttr);
1233 };
1234
1235 mlir::OpBuilder builder(&getContext());
1236 Operation *op = getOperation();
1237 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1238 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1239 LogicalResult r = success();
1240 TypeSwitch<Operation *>(&op)
1241 .Case<mlir::RegionBranchTerminatorOpInterface>(
1242 [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1243 r = updateControlFlowOps(builder, branchTermOp,
1244 getXeGPULayoutForValue);
1245 })
1246 .Case<mlir::FunctionOpInterface>(
1247 [&](mlir::FunctionOpInterface funcOp) {
1248 r = updateFunctionOpInterface(builder, funcOp,
1249 getXeGPULayoutForValue);
1250 })
1251 .Default([&](Operation *op) {
1252 r = updateOp(builder, op, getXeGPULayoutForValue);
1253 });
1254 if (failed(r)) {
1255 op.emitError("Failed to update operation with the layout.");
1256 return WalkResult::interrupt();
1257 }
1258 }
1259 return WalkResult::advance();
1260 });
1261 if (walkResult.wasInterrupted()) {
1262 signalPassFailure();
1263 return;
1264 }
1265}