1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/Builders.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/Value.h"
25#include "mlir/IR/Visitors.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallSet.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/LogicalResult.h"
37#include "llvm/Support/raw_ostream.h"
38
40
41namespace mlir {
42namespace xegpu {
43#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
45} // namespace xegpu
46} // namespace mlir
47
48#define DEBUG_TYPE "xegpu-propagate-layout"
49#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
50
51using namespace mlir;
52using namespace mlir::dataflow;
53
54namespace {
55
56enum class LayoutKind { Lane, InstData, Subgroup };
57
58//===----------------------------------------------------------------------===//
59// LayoutInfo
60//===----------------------------------------------------------------------===//
61
62/// Helper class for tracking the analysis state of an MLIR value. For layout
63/// propagation, the analysis state is simply the distribution layout of
64/// each value. The distribution layout information is encapsulated in the
65/// xegpu::DistributeLayoutAttr class, which can hold any kind of distribution
66/// layout that the XeGPU dialect supports. The purpose of this analysis is to
67/// propagate a unique distribution layout to each value in the program,
68/// starting from a set of anchor operations (like DPAS, StoreNd, etc.). Note
69/// that the analysis reaches a fixed point once every value has been assigned
70/// some layout; the analysis does not try to modify already assigned layouts.
71///
72/// Given this, LayoutInfo satisfies the following properties:
73/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
74/// assigned`.
75/// 2) Two LayoutInfo values are equal if they are both assigned or
76/// both not assigned. The concrete value of the assigned state does not matter.
77/// 3) The meet operator works as follows:
78/// - If the current state is assigned, return the current state (a unique
79/// layout is already assigned; don't change it).
80/// - Otherwise, return the other state.
81
82struct LayoutInfo {
83private:
84 xegpu::DistributeLayoutAttr storage = nullptr;
85
86public:
87 LayoutInfo() = default;
88 LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
89
90 // Two lattice values are equal if they have `some` layout. The actual
91 // content of the layout does not matter.
92 bool operator==(const LayoutInfo &other) const {
93 return this->isAssigned() == other.isAssigned();
94 }
95
96 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
97
98 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
99
100 void print(raw_ostream &os) const;
101
102 bool isAssigned() const { return storage != nullptr; }
103
104 LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
105
106 SmallVector<int> getLaneLayout() const;
107
108 SmallVector<int> getLaneData() const;
109
110 SmallVector<int> getInstData() const;
111
112 SmallVector<int> getSgLayout() const;
113
114 SmallVector<int> getSgData() const;
115
116 SmallVector<int> getOrder() const;
117
118 bool isSliceLayout() const {
119 if (!isAssigned())
120 return false;
121 return isa<xegpu::SliceAttr>(storage);
122 }
123
124 int64_t getRank() const {
125 if (!isAssigned())
126 return -1;
127 return storage.getRank();
128 }
129
130 Attribute get() { return storage; }
131};
132
133SmallVector<int> LayoutInfo::getLaneLayout() const {
134 if (!isAssigned())
135 return {};
136 return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
137 [](int64_t val) { return static_cast<int>(val); });
138}
139
140SmallVector<int> LayoutInfo::getLaneData() const {
141 if (!isAssigned())
142 return {};
143 return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
144 [](int64_t val) { return static_cast<int>(val); });
145}
146
147SmallVector<int> LayoutInfo::getInstData() const {
148 if (!isAssigned())
149 return {};
150 return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
151 [](int64_t val) { return static_cast<int>(val); });
152}
153
154SmallVector<int> LayoutInfo::getSgLayout() const {
155 if (!isAssigned())
156 return {};
157 return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
158 [](int64_t val) { return static_cast<int>(val); });
159}
160
161SmallVector<int> LayoutInfo::getSgData() const {
162 if (!isAssigned())
163 return {};
164 return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
165 [](int64_t val) { return static_cast<int>(val); });
166}
167
168SmallVector<int> LayoutInfo::getOrder() const {
169 if (!isAssigned() || !storage.getOrder())
170 return {};
171 return llvm::map_to_vector(storage.getOrder().asArrayRef(),
172 [](int64_t val) { return static_cast<int>(val); });
173}
174
175void LayoutInfo::print(raw_ostream &os) const {
176 if (isAssigned()) {
177 os << storage;
178 } else {
179 os << "Not assigned.";
180 }
181}
182
183LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
184 if (!lhs.isAssigned())
185 return rhs;
186 return lhs;
187}
188
189/// Since this is a backward analysis, join method is not used.
190LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
191 llvm_unreachable("Join should not be triggered by layout propagation.");
192}
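// Illustrative sketch of the lattice semantics above (hypothetical values, not
// part of the pass): given an unassigned `LayoutInfo a` and an assigned
// `LayoutInfo b(layoutAttr)`,
//   LayoutInfo::meet(a, b)  // returns b: an unassigned state adopts the other
//   LayoutInfo::meet(b, a)  // returns b: an assigned state is never changed
// and `a == b` is false while `meet(a, b) == b` is true, since equality only
// compares assigned-ness.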
193
194/// Construct a new layout by applying `permutation` to the lane_layout/lane_data,
195/// inst_data, sg_layout/sg_data, and order parameters of this layout.
196LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
197 if (!isAssigned())
198 return {};
199 // Check if the permutation is valid.
200 llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
201 bool hasDuplicates = seen.size() != permutation.size();
202 bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
203 return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
204 });
205
206 if (!withinRange || hasDuplicates) {
207 assert(false && "Invalid permutation for transpose.");
208 return {};
209 }
210
211 SmallVector<int32_t> laneLayout;
212 SmallVector<int32_t> laneData;
213 SmallVector<int32_t> instData;
214 SmallVector<int32_t> sgLayout;
215 SmallVector<int32_t> sgData;
216 SmallVector<int32_t> order;
217
218 for (int64_t idx : permutation) {
219 if (getLaneLayout().size()) {
220 laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
221 laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
222 }
223 if (getInstData().size())
224 instData.push_back(static_cast<int32_t>(getInstData()[idx]));
225 if (getSgData().size()) {
226 sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
227 sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
228 }
229 if (getOrder().size()) {
230 order.push_back(static_cast<int32_t>(getOrder()[idx]));
231 }
232 }
233 auto orderAttr = order.size()
234 ? DenseI32ArrayAttr::get(storage.getContext(), order)
235 : nullptr;
236 xegpu::LayoutAttr layoutAttr;
237 if (getLaneLayout().size())
238 layoutAttr =
239 xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
240 if (getInstData().size())
241 layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
242 if (getSgData().size())
243 layoutAttr = xegpu::LayoutAttr::get(
244 storage.getContext(),
245 DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
246 DenseI32ArrayAttr::get(storage.getContext(), sgData),
247 /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
248 /*lane_data =*/nullptr, orderAttr);
249 return LayoutInfo(layoutAttr);
250}
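// Illustrative example (values assumed for clarity): transposing a layout with
// lane_layout = [1, 16] and lane_data = [1, 1] using permutation [1, 0] yields
// lane_layout = [16, 1] and lane_data = [1, 1]; inst_data and sg_layout/sg_data,
// when present, are permuted the same way.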
251
252//===----------------------------------------------------------------------===//
253// LayoutInfoLattice
254//===----------------------------------------------------------------------===//
255
256/// Lattice holding the LayoutInfo for each value.
257struct LayoutInfoLattice : public Lattice<LayoutInfo> {
259 using Lattice::Lattice;
260};
261
262/// Helper Functions to get default layouts. A `default layout` is a layout that
263/// is assigned to a value when the layout is not fixed by some anchor operation
264/// (like DPAS).
265
266/// Helper Function to get the default layout for uniform values like constants.
267/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
268/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
269static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
270 unsigned rank,
271 const xegpu::uArch::uArch *uArch) {
272 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
273 if (rank == 1) {
274 return LayoutInfo(
275 xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
276 }
277 return LayoutInfo(
278 xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
279}
280
281static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
282 unsigned rank, int subgroupSize) {
283 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
284 if (rank == 1) {
285 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
286 }
287 return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
288}
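// Hedged usage sketch (illustrative only; `ctx` is a placeholder and a subgroup
// size of 16 is assumed):
//   LayoutInfo dflt = getDefaultSIMTLayoutInfo(ctx, /*rank=*/2, /*subgroupSize=*/16);
// produces #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, i.e. one
// row of lanes where each lane owns a single element.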
289
290/// Helper to get the default layout for 2D block operations.
291template <typename Ty>
292static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
293 const xegpu::uArch::uArch *uArch,
294 unsigned packingSize) {
295 // Expecting a 1D or 2D vector.
296 assert((ty.getRank() == 1 || ty.getRank() == 2) &&
297 "Expected 1D or 2D vector.");
298 // Expecting int or float element type.
299 assert(ty.getElementType().isIntOrFloat() &&
300 "Expected int or float element type.");
301 // If the rank is 1, then return default layout for 1D vector.
302 if (ty.getRank() == 1)
303 return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
304 // Packing factor is determined by the element type bitwidth.
305 unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
306 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307 return LayoutInfo(xegpu::LayoutAttr::get(
308 ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
309}
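// Illustrative example (assumed values): for a 2D block of f16 elements with
// packingSize = 32 and a subgroup size of 16, packingFactor = 32 / 16 = 2, so
// the resulting layout is lane_layout = [1, 16], lane_data = [1, 2].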
310
311/// Helper to get the default layout for a vector type used in scatter/gather IO.
312static LayoutInfo
313getSIMTLayoutInforForScatterIO(VectorType vectorTy,
314 const xegpu::uArch::uArch *uArch,
315 unsigned packingSize) {
316 // Expecting a 1D or 2D vector.
317 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
318 "Expected 1D or 2D vector.");
319 // Expecting int or float element type.
320 assert(vectorTy.getElementType().isIntOrFloat() &&
321 "Expected int or float element type.");
322 // If the rank is 1, then return default layout for 1D vector.
323 if (vectorTy.getRank() == 1)
324 return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
325 // Packing factor is determined by the element type bitwidth.
326 unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
327 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
328 return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
329 {uArch->getSubgroupSize(), 1},
330 {1, packingFactor}));
331}
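// Illustrative example (assumed values): for vector<16x2xf16> with
// packingSize = 32 and a subgroup size of 16, packingFactor = 2 and the layout
// is lane_layout = [16, 1], lane_data = [1, 2], i.e. each lane owns one chunk
// of two packed f16 values.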
332
333/// Helper function to get the expected layout for a DPAS operand. `lane_data`
334/// is set according to the following criteria:
335/// * For the A operand, the data must be packed to at least the uArch's
336/// packed format bit size for A.
337/// * For the B operand, the data must be packed to at least the uArch's
338/// packed format bit size for B (VNNI format).
339static LayoutInfo
340getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
341 const xegpu::uArch::uArch *uArch,
342 unsigned packingSize) {
343 Type elementTy = vectorTy.getElementType();
344 assert(elementTy.isIntOrFloat() &&
345 "Expected int or float type in DPAS operands");
347 // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
348 // must have the VNNI format.
349 if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
350 SmallVector<int32_t> data(
351 {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
352 1});
353 return LayoutInfo(
354 xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
355 }
356 // Otherwise, return the default layout for the vector type.
357 return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
358}
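// Illustrative example (assumed values): for a 16-bit B operand (f16/bf16) with
// packingSize = 32 and a subgroup size of 16, the layout is
// lane_layout = [1, 16] with lane_data = [2, 1], i.e. each lane owns two
// vertically adjacent elements (VNNI packing along the K dimension).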
359
360//===----------------------------------------------------------------------===//
361// LayoutInfoPropagation
362//===----------------------------------------------------------------------===//
363
364/// Backward data flow analysis to propagate the lane_layout and lane_data of
365/// each value in the program. Currently, the layouts for the operands of DPAS,
366/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
367/// of this analysis is to propagate those known layouts to all of their
368/// producers and (other) consumers.
369class LayoutInfoPropagation
370 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
371private:
372 LayoutKind layoutKind;
373 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
374 ArrayRef<const LayoutInfoLattice *> results);
375
376 void visitStoreNdOp(xegpu::StoreNdOp store,
377 ArrayRef<LayoutInfoLattice *> operands,
378 ArrayRef<const LayoutInfoLattice *> results);
379
380 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
381 ArrayRef<LayoutInfoLattice *> operands,
382 ArrayRef<const LayoutInfoLattice *> results);
383
384 void visitLoadNdOp(xegpu::LoadNdOp load,
385 ArrayRef<LayoutInfoLattice *> operands,
386 ArrayRef<const LayoutInfoLattice *> results);
387
388 void visitLoadGatherOp(xegpu::LoadGatherOp load,
389 ArrayRef<LayoutInfoLattice *> operands,
390 ArrayRef<const LayoutInfoLattice *> results);
391
392 void visitTransposeOp(vector::TransposeOp transpose,
393 ArrayRef<LayoutInfoLattice *> operands,
394 ArrayRef<const LayoutInfoLattice *> results);
395
396 void visitVectorBitcastOp(vector::BitCastOp bitcast,
397 ArrayRef<LayoutInfoLattice *> operands,
398 ArrayRef<const LayoutInfoLattice *> results);
399
400 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
401 ArrayRef<LayoutInfoLattice *> operands,
402 ArrayRef<const LayoutInfoLattice *> results);
403
404 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
405 ArrayRef<LayoutInfoLattice *> operands,
406 ArrayRef<const LayoutInfoLattice *> results);
407
408 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
409 ArrayRef<LayoutInfoLattice *> operands,
410 ArrayRef<const LayoutInfoLattice *> results);
411
412 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
413 ArrayRef<LayoutInfoLattice *> operands,
414 ArrayRef<const LayoutInfoLattice *> results);
415
416 void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
417 ArrayRef<LayoutInfoLattice *> operands,
418 ArrayRef<const LayoutInfoLattice *> results);
419 void visitShapeCastOp(vector::ShapeCastOp shapeCast,
420 ArrayRef<LayoutInfoLattice *> operands,
421 ArrayRef<const LayoutInfoLattice *> results);
422
423 bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
424
425public:
426 LayoutInfoPropagation(DataFlowSolver &solver,
427 SymbolTableCollection &symbolTable,
428 LayoutKind layoutKind)
429 : SparseBackwardDataFlowAnalysis(solver, symbolTable),
430 layoutKind(layoutKind) {}
432
433 LogicalResult
434 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
435 ArrayRef<const LayoutInfoLattice *> results) override;
436
437 void visitBranchOperand(OpOperand &operand) override {};
438
439 void visitCallOperand(OpOperand &operand) override {};
440
441 void
442 visitNonControlFlowArguments(RegionSuccessor &successor,
443 ArrayRef<BlockArgument> arguments) override {};
444
445 void visitExternalCall(CallOpInterface call,
446 ArrayRef<LayoutInfoLattice *> operands,
447 ArrayRef<const LayoutInfoLattice *> results) override {
448 };
449
450 void setToExitState(LayoutInfoLattice *lattice) override {
451 (void)lattice->meet(LayoutInfo());
452 }
453};
454} // namespace
455
456LogicalResult LayoutInfoPropagation::visitOperation(
457 Operation *op, ArrayRef<LayoutInfoLattice *> operands,
458 ArrayRef<const LayoutInfoLattice *> results) {
459 TypeSwitch<Operation *>(op)
460 .Case<xegpu::DpasOp>(
461 [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
462 .Case<xegpu::StoreNdOp>(
463 [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
464 .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
465 visitStoreScatterOp(storeScatterOp, operands, results);
466 })
467 .Case<xegpu::LoadNdOp>(
468 [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
469 .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
470 visitLoadGatherOp(loadGatherOp, operands, results);
471 })
472 .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
473 visitCreateDescOp(createDescOp, operands, results);
474 })
475 .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
476 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
477 })
478 .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
479 visitPrefetchNdOp(prefetchNdOp, operands, results);
480 })
481 .Case<vector::TransposeOp>([&](auto transposeOp) {
482 visitTransposeOp(transposeOp, operands, results);
483 })
484 .Case<vector::BitCastOp>([&](auto bitcastOp) {
485 visitVectorBitcastOp(bitcastOp, operands, results);
486 })
487 .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
488 visitVectorMultiReductionOp(reductionOp, operands, results);
489 })
490 .Case<vector::BroadcastOp>([&](auto broadcastOp) {
491 visitVectorBroadCastOp(broadcastOp, operands, results);
492 })
493 .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
494 visitShapeCastOp(shapeCastOp, operands, results);
495 })
496 // All other ops.
497 .Default([&](Operation *op) {
498 for (const LayoutInfoLattice *resultInfo : results) {
499 if (!resultInfo->getValue().isAssigned())
500 continue;
501 for (auto [operandInfo, operand] :
502 llvm::zip(operands, op->getOpOperands())) {
503 // If the operand type is not a vector or tensor descriptor, skip
504 // it.
505 if (!isa<xegpu::TensorDescType, VectorType>(
506 operand.get().getType()))
507 continue;
508 // Propagate the result layout to the operand.
509 meet(operandInfo, *resultInfo);
510 }
511 }
512 });
513
514 return success();
515}
516
517bool LayoutInfoPropagation::hasParamsOfLayoutKind(
518 xegpu::DistributeLayoutAttr anchorLayout) {
519 if (anchorLayout == nullptr) {
520 return false;
521 }
522 if (layoutKind == LayoutKind::InstData) {
523 return !(anchorLayout.getEffectiveInstDataAsInt().empty());
524 } else if (layoutKind == LayoutKind::Lane) {
525 return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
526 anchorLayout.getEffectiveLaneDataAsInt().empty());
527 } else if (layoutKind == LayoutKind::Subgroup) {
528 return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
529 anchorLayout.getEffectiveSgDataAsInt().empty());
530 }
531 return false;
532}
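// Note (explanatory, reflecting the logic above): an anchor layout only
// short-circuits propagation when it carries parameters for the kind being
// propagated. For example, with LayoutKind::Lane an anchor attribute that only
// specifies inst_data returns false here, so the visit functions fall back to
// computing a default lane layout.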
533
534void LayoutInfoPropagation::visitPrefetchNdOp(
535 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
536 ArrayRef<const LayoutInfoLattice *> results) {
537
538 LayoutInfo prefetchLayout;
539 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
540 if (hasParamsOfLayoutKind(anchorLayout)) {
541 prefetchLayout = LayoutInfo(anchorLayout);
542 } else {
543 // Here we assign the default layout to the tensor descriptor operand of
544 // prefetch.
545 auto tdescTy = prefetch.getTensorDescType();
546
547 auto uArch = getUArch(getChipStr(prefetch).value_or(""));
548 const auto *uArchInstruction =
549 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
550 uArch->getInstruction(
551 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
552
553 auto blockWHC =
554 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
555 if (!blockWHC)
556 prefetch.emitWarning("No known block params found for the element type.");
557 auto [bWidth, bHeight, bCount] = blockWHC.value();
558 SmallVector<int> instData;
559 int instWidth = xegpu::getLargestDivisor(
560 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
561 if (instWidth == -1)
562 prefetch.emitWarning(
563 "No suitable instruction multiple found for the given shape.");
564 if (tdescTy.getRank() == 1)
565 instData = {instWidth};
566 else {
567 int instHeight = xegpu::getLargestDivisor(
568 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
569 if (instHeight == -1)
570 prefetch.emitWarning(
571 "No suitable instruction multiple found for the given shape.");
572 instData = {instHeight, instWidth};
573 }
574
575 if (layoutKind == LayoutKind::InstData)
576 prefetchLayout =
577 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
578 else
579 prefetchLayout = getSIMTLayoutInforForBlockIO(
580 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
581
582 prefetch.setLayoutAttr(
583 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
584 }
585 // Propagate the layout to the source tensor descriptor.
586 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
587}
588
589void LayoutInfoPropagation::visitVectorMultiReductionOp(
590 vector::MultiDimReductionOp reduction,
591 ArrayRef<LayoutInfoLattice *> operands,
592 ArrayRef<const LayoutInfoLattice *> results) {
593 // The layout of the result must be present.
594 LayoutInfo resultLayout = results[0]->getValue();
595 if (!resultLayout.isAssigned())
596 return;
597 // We only consider 2D -> 1D reductions at this point.
598 VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
599 if (!resultTy || resultTy.getRank() != 1) {
600 reduction.emitWarning("Expecting output type to be 1D vector.");
601 return;
602 }
603 auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
604 // Given that the result is 1D, the layout of the operand should be 2D with
605 // default layout.
606 LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
607 reduction->getContext(), 2, uArch->getSubgroupSize());
608 propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
609 // Accumulator should have the same layout as the result.
610 propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
611}
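// Illustrative example (assumed shapes): reducing vector<16x32xf32> along
// dim 1 to vector<16xf32>: the source operand receives the default 2D layout
// (lane_layout = [1, subgroupSize], lane_data = [1, 1]) and the accumulator
// inherits whatever 1D layout was already assigned to the result.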
612
613void LayoutInfoPropagation::visitVectorBroadCastOp(
614 vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
615 ArrayRef<const LayoutInfoLattice *> results) {
616 // The layout of the result must be present.
617 LayoutInfo resultLayout = results[0]->getValue();
618 if (!resultLayout.isAssigned())
619 return;
620 // Only consider vector to vector broadcasts for now.
621 VectorType resultTy = broadcast.getResultVectorType();
622 VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
623 // skip layout propagation for non-vector source operand.
624 if (!sourceTy)
625 return;
626
627 // Handle the broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
628 if (sourceTy.getRank() != resultTy.getRank()) {
629 auto sourceDims = sourceTy.getShape();
630 auto resultDims = resultTy.getShape();
631 SmallVector<int64_t> bcastDims;
632 auto dimDiff = resultTy.getRank() - sourceTy.getRank();
633 // adding the missing leading dims
634 for (int i = 0; i < dimDiff; i++)
635 bcastDims.push_back(i);
636
637 // For the remaining dims in resultTy, if the corresponding sourceTy dim is 1
638 // and the resultTy dim is not, it is a broadcasted dim.
639 for (size_t i = 0; i < sourceDims.size(); i++)
640 if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
641 bcastDims.push_back(i + dimDiff);
642
643 // create a slice layout for the source
644 xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
645 broadcast->getContext(),
646 cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
647 DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
648
649 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
650 return;
651 }
652 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
653}
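// Illustrative example (assumed shapes): broadcasting vector<16xf32> to
// vector<8x16xf32> gives dimDiff = 1 and bcastDims = [0], so the source is
// assigned a slice of the result layout along dim 0.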
654
655void LayoutInfoPropagation::visitShapeCastOp(
656 vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
657 ArrayRef<const LayoutInfoLattice *> results) {
658 // The layout of the result must be present.
659 LayoutInfo resultLayout = results[0]->getValue();
660 if (!resultLayout.isAssigned())
661 return;
662 VectorType sourceTy = shapeCast.getSourceVectorType();
663 VectorType resultTy = shapeCast.getResultVectorType();
664 // Shape cast layout propagation only supports 1D -> 2D shape casts.
665 // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
666 // unit dimensions and non-unit dims match.
667 if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
668 shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
669 return;
670 }
671 int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
672 xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
673 shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
674 DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
675 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
676}
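// Illustrative example (assumed shapes): for vector<16xf32> -> vector<1x16xf32>
// slicedDim = 0, so the 1D source receives the result's 2D layout sliced along
// dim 0 (and for vector<16xf32> -> vector<16x1xf32>, slicedDim = 1).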
677
678/// Propagate the layout of the result tensor to the source tensor descriptor
679/// in UpdateNdOffsetOp.
680void LayoutInfoPropagation::visitUpdateNdOffsetOp(
681 xegpu::UpdateNdOffsetOp updateNdOffset,
682 ArrayRef<LayoutInfoLattice *> operands,
683 ArrayRef<const LayoutInfoLattice *> results) {
684 // The layout of the result must be present.
685 LayoutInfo resultLayout = results[0]->getValue();
686 if (!resultLayout.isAssigned())
687 return;
688 // Propagate the layout to the source operand.
689 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
690}
691
692/// Set the layouts for DPAS A, B, and C operands.
693void LayoutInfoPropagation::visitDpasOp(
694 xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
695 ArrayRef<const LayoutInfoLattice *> results) {
696
697 LayoutInfo dpasALayout;
698 LayoutInfo dpasBLayout;
699 LayoutInfo dpasCDLayout;
700
701 xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
702 if (hasParamsOfLayoutKind(anchorLayoutCD)) {
703 xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
704 xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
705 assert(hasParamsOfLayoutKind(anchorLayoutA) &&
706 "Expected anchor layout for DPAS A operand.");
707 assert(hasParamsOfLayoutKind(anchorLayoutB) &&
708 "Expected anchor layout for DPAS B operand.");
709 dpasALayout = LayoutInfo(anchorLayoutA);
710 dpasBLayout = LayoutInfo(anchorLayoutB);
711 dpasCDLayout = LayoutInfo(anchorLayoutCD);
712
713 } else {
714
715 VectorType aTy = dpas.getLhsType();
716 VectorType bTy = dpas.getRhsType();
717
718 auto uArch = getUArch(getChipStr(dpas).value_or(""));
719 const int subgroupSize = uArch->getSubgroupSize();
720 const auto *uArchInstruction =
721 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
722 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
723
724 const unsigned dataALen = aTy.getShape().front();
725 auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
726 const int maxALen =
727 xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
728 if (maxALen == -1)
729 dpas.emitWarning(
730 "No suitable instruction multiple found for the given shape.");
731
732 const unsigned dataBLen = bTy.getShape().back();
733 auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
734
735 const int maxBLen =
736 xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
737
738 if (maxBLen == -1)
739 dpas.emitWarning(
740 "No suitable instruction multiple found for the given shape.");
741 SmallVector<int> instDataA = {maxALen, subgroupSize};
742 SmallVector<int> instDataB = {subgroupSize, maxBLen};
743
744 if (layoutKind == LayoutKind::InstData) {
745 dpasALayout =
746 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
747 dpasBLayout =
748 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
749 } else {
750 dpasALayout = getSIMTLayoutInfoForDPASOperand(
751 aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
752 dpasBLayout = getSIMTLayoutInfoForDPASOperand(
753 bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
754 }
755
756 if (operands.size() > 2) {
757 VectorType cTy = dpas.getAccType();
758 if (layoutKind == LayoutKind::InstData) {
759 const unsigned dataCLen = bTy.getShape().back();
760 auto supportedCLen =
761 uArchInstruction->getSupportedN(bTy.getElementType());
762 const int maxCLen = xegpu::getLargestDivisor(
763 dataCLen, ArrayRef<unsigned>(supportedCLen));
764 if (maxCLen == -1)
765 dpas.emitWarning(
766 "No suitable instruction multiple found for the given shape.");
767 SmallVector<int> instDataC = {maxALen, maxCLen};
768 dpasCDLayout =
769 LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
770 } else
771 dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
772 cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
773
774 dpas.setLayoutCdAttr(
775 dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
776 }
777 dpas.setLayoutAAttr(
778 dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
779 dpas.setLayoutBAttr(
780 dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
781 }
782
783 propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
784 propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
785 if (operands.size() > 2) {
786 propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
787 }
788}
789
790/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
791void LayoutInfoPropagation::visitStoreNdOp(
792 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
793 ArrayRef<const LayoutInfoLattice *> results) {
794
795 LayoutInfo storeLayout;
796 xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
797 if (hasParamsOfLayoutKind(anchorLayout)) {
798 storeLayout = LayoutInfo(anchorLayout);
799 } else {
800 auto uArch = getUArch(getChipStr(store).value_or(""));
801 const auto *uArchInstruction =
802 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
803 uArch->getInstruction(
804 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
805 VectorType dataTy = store.getValueType();
806 auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
807 store.getValueType().getElementType());
808 if (!blockWHC)
809 store.emitWarning("No known block params found for the element type.");
810 auto [bWidth, bHeight, bCount] = blockWHC.value();
811 SmallVector<int> instData;
812 int instWidth = xegpu::getLargestDivisor(
813 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
814 if (instWidth == -1)
815 store.emitWarning(
816 "No suitable instruction multiple found for the given shape.");
817 if (dataTy.getRank() == 1)
818 instData = {instWidth};
819 else {
820 int instHeight = xegpu::getLargestDivisor(
821 static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
822 if (instHeight == -1)
823 store.emitWarning(
824 "No suitable instruction multiple found for the given shape.");
825 instData = {instHeight, instWidth};
826 }
827
828 if (layoutKind == LayoutKind::InstData)
829 storeLayout =
830 LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
831 else
832 storeLayout = getSIMTLayoutInforForBlockIO(
833 store.getValueType(), uArch,
834 uArchInstruction->getPackedFormatBitSize());
835 store.setLayoutAttr(
836 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
837 }
838 // Propagate the layout to the value operand.
839 // Both operands should have the same layout
840 for (LayoutInfoLattice *operand : operands)
841 propagateIfChanged(operand, operand->meet(storeLayout));
842}
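// Illustrative example (assumed values): storing a vector<32x64xf16> where the
// uArch reports block width candidates {16} and block height candidates {8}:
// instWidth = 16 and instHeight = 8, so inst_data = [8, 16] when propagating
// inst_data; otherwise the 2D block-IO SIMT layout is used for both the value
// and the tensor descriptor.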
843
844/// Propagate the layout of the value to the tensor descriptor operand in
845/// LoadNdOp.
846void LayoutInfoPropagation::visitLoadNdOp(
847 xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
848 ArrayRef<const LayoutInfoLattice *> results) {
849
850 LayoutInfo loadLayout;
851 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
852 if (hasParamsOfLayoutKind(anchorLayout)) {
853 loadLayout = LayoutInfo(anchorLayout);
854 } else {
855
856 LayoutInfo valueLayout = results[0]->getValue();
857 // Need the layout of the value to propagate to the tensor descriptor.
858 if (!valueLayout.isAssigned())
859 return;
860 loadLayout = valueLayout;
861 // LoadNdOp can have a transpose effect. However, at this stage of the
862 // analysis the effect is not expected to be present and should have been
863 // abstracted away. Emit a warning.
864 if (auto transpose = load.getTranspose()) {
865 load.emitWarning("Transpose effect is not expected for LoadNdOp at "
866 "LayoutInfoPropagation stage.");
867 loadLayout = valueLayout.transpose(transpose.value());
868 }
869 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
870 }
871 // Propagate the new layout to the tensor descriptor operand.
872 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
873}
874
875/// For vector::TransposeOp, the layout of the result is transposed and
876/// propagated to the operand.
877void LayoutInfoPropagation::visitTransposeOp(
878 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
879 ArrayRef<const LayoutInfoLattice *> results) {
880 // Need the layout of transpose result to propagate to the operands.
881 LayoutInfo resultLayout = results[0]->getValue();
882 if (!resultLayout.isAssigned())
883 return;
884 LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
885 // Propagate the new layout to the vector operand.
886 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
887}
888
889/// For vector::BitCastOp, the lane_data of the source layout is changed based
890/// on the bit width of the source and result types.
891void LayoutInfoPropagation::visitVectorBitcastOp(
892 vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
893 ArrayRef<const LayoutInfoLattice *> results) {
894 // Need the layout of bitcast result to propagate to the operands.
895 LayoutInfo resultLayout = results[0]->getValue();
896 if (!resultLayout.isAssigned())
897 return;
898 int inElemTyBitWidth =
899 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
900 int outElemTyBitWidth =
901 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
902 // If the element bit widths are the same, then the layout does not change.
903 if (inElemTyBitWidth == outElemTyBitWidth) {
904 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
905 return;
906 }
907 // Check if the result layout is valid, i.e. the result vector can be distributed.
908 auto resultLaneLayout = resultLayout.getLaneLayout();
909 auto resultLaneData = resultLayout.getLaneData();
910 if (failed(xegpu::getDistributedVectorType(
911 bitcast.getResultVectorType(),
912 xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
913 resultLaneData)))) {
914 bitcast.emitWarning(
915 "Result vector type can not be evenly distributed across lanes.");
916 return;
917 }
918 int64_t rank = bitcast.getSourceVectorType().getRank();
919 // A bitcast is `narrowing` if the input element type bit width is larger than
920 // the output element type bit width, e.g. f32 -> f16 is a narrowing bitcast.
921 bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
922 int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
923 : outElemTyBitWidth / inElemTyBitWidth;
924 SmallVector<int> sourceLaneLayout =
925 resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
926 SmallVector<int> outData = resultLayout.getLaneData();
927
928 // TODO: Currently we assume that bitcasts do not require cross-lane
929 // communication, so each lane must own enough elements to perform the
930 // bitcast locally.
931 int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
932 if (outInnerBitsPerLane < inElemTyBitWidth) {
933 bitcast.emitWarning(
934 "Narrowing bitcast with cross lane communication is not supported.");
935 return;
936 }
937 // Check if each lane owns a single element in all dimensions except the
938 // innermost dimension.
939 SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
940 if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
941 bitcast.emitWarning("Each lane must not own multiple elements in any "
942 "dimension other than "
943 "the innermost dimension.");
944 return;
945 }
946 // Decide lane data based on whether the bitcast is narrowing or widening.
947 int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
948 : outData[rank - 1] * bitCastRatio;
949 sourceLaneData.push_back(innerMostLaneData);
950
951 propagateIfChanged(
952 operands[0],
953 operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
954 bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
955}
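// Illustrative example (assumed types): bitcasting vector<8x32xf16> to
// vector<8x16xf32> is a widening bitcast with bitCastRatio = 2; if the result
// lane_data is [1, 1], the source lane_data becomes [1, 2] so that each lane
// locally owns the two f16 values forming one f32.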
956
957/// Propagate the layout of the result to the tensor descriptor, mask and offset
958/// operands in LoadGatherOp.
959void LayoutInfoPropagation::visitLoadGatherOp(
960 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
961 ArrayRef<const LayoutInfoLattice *> results) {
962
963 LayoutInfo loadLayout;
964 LayoutInfo maskLayout;
965 xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
966 if (hasParamsOfLayoutKind(anchorLayout)) {
967 loadLayout = LayoutInfo(anchorLayout);
968 maskLayout = loadLayout;
969 } else {
970
971 // The layout is strictly determined by the payload type.
972 VectorType payloadTy = load.getValueType();
973 if (!payloadTy) {
974 load.emitWarning("Not propagating, non-vector payload supplied.");
975 return;
976 }
977 auto uArch = getUArch(getChipStr(load).value_or(""));
978 const int subgroupSize = uArch->getSubgroupSize();
979 SmallVector<int> instData{subgroupSize};
980 if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
981 instData.push_back(chunkSize);
982 else if (auto srcTdescTy =
983 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
984 if (srcTdescTy.getChunkSizeAsInt() > 1)
985 instData.push_back(srcTdescTy.getChunkSizeAsInt());
986 }
987
988 if (layoutKind == LayoutKind::InstData)
989 loadLayout =
990 LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
991 else
992 loadLayout = getSIMTLayoutInforForScatterIO(
993 payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
994
995 // Mask operand should have 1D default layout.
996 maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
997
998 load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
999 }
1000 // Propagate the new layout to the tensor descriptor operand.
1001 if (isa<xegpu::TensorDescType>(load.getSourceType()))
1002 propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
1003 // Propagate the new layout to the mask and optional offset operand.
1004 propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
1005 if (load.getOffsets())
1006 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
1007}
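// Illustrative example (assumed values): a gather load with chunk_size = 2 and
// a subgroup size of 16 yields instData = [16, 2] when propagating inst_data;
// otherwise the scatter SIMT layout (lane_layout = [16, 1]) is used for the
// payload, and the mask/offset operands get the default 1D layout
// (lane_layout = [16], lane_data = [1]).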
1008
1009/// Propagate the layout of the descriptor to the vector offset operand in
1010/// CreateDescOp.
1011void LayoutInfoPropagation::visitCreateDescOp(
1012 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
1013 ArrayRef<const LayoutInfoLattice *> results) {
1014 LayoutInfo descLayout = results[0]->getValue();
1015 // Need the layout of the descriptor to propagate to the operands.
1016 if (!descLayout.isAssigned())
1017 return;
1018 auto uArch = getUArch(getChipStr(createDesc).value_or(""));
1019 // For offset operand propagate 1D default layout.
1020 LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
1021 uArch->getSubgroupSize());
1022 propagateIfChanged(operands[1], operands[1]->meet(layout));
1023}
1024
1025/// Set the layout for the value, tensor descriptor, offset and mask operands in
1026/// the StoreScatterOp.
1027void LayoutInfoPropagation::visitStoreScatterOp(
1028 xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
1029 ArrayRef<const LayoutInfoLattice *> results) {
1030
1031 LayoutInfo payloadLayout;
1032 LayoutInfo maskLayout;
1033 xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
1034 if (hasParamsOfLayoutKind(anchorLayout)) {
1035 payloadLayout = LayoutInfo(anchorLayout);
1036 maskLayout = payloadLayout;
1037 } else {
1038 // Currently, for 2D StoreScatterOp we expect that the height dimension of
1039 // the tensor descriptor is equal to the subgroup size. This is ensured by
1040 // the op verifier.
1041 VectorType payloadTy = storeScatter.getValueType();
1042 if (!payloadTy) {
1043 storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
1044 return;
1045 }
1046
1047 auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
1048 const int subgroupSize = uArch->getSubgroupSize();
1049
1050 if (layoutKind == LayoutKind::InstData) {
1051 SmallVector<int> instData{subgroupSize};
1052 if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
1053 chunkSize > 1)
1054 instData.push_back(chunkSize);
1055 else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
1056 storeScatter.getDestType())) {
1057 if (dstTdescTy.getChunkSizeAsInt() > 1)
1058 instData.push_back(dstTdescTy.getChunkSizeAsInt());
1059 }
1060 payloadLayout = LayoutInfo(
1061 xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
1062 } else {
1063 auto payloadShape = payloadTy.getShape();
1064 if (payloadShape.size() > 1)
1065 assert(payloadShape[0] == subgroupSize &&
1066 "Expected the first dimension of 2D tensor descriptor to be "
1067 "equal to "
1068 "subgroup size.");
1069 payloadLayout = getSIMTLayoutInforForScatterIO(
1070 payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
1071 }
1072
1073 maskLayout =
1074 getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
1075
1076 storeScatter.setLayoutAttr(
1077 dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
1078 }
1079 // Propagate the payload operand layout
1080 propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
1081 // Propagate the destination (if tdesc) operand layout
1082 if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
1083 propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
1084 // Propagate the new layout to the mask and optional offset operand.
1085 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
1086 if (storeScatter.getOffsets())
1087 propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
1088}
1089
1090namespace {
1091//===----------------------------------------------------------------------===//
1092// RunLayoutInfoPropagation
1093//===----------------------------------------------------------------------===//
1094
1095/// Driver class for running the LayoutInfoPropagation analysis.
1096class RunLayoutInfoPropagation {
1097public:
1098 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
1099
1100 RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
1101 SymbolTableCollection symbolTable;
1102 loadBaselineAnalyses(solver);
1103 solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
1104 (void)solver.initializeAndRun(op);
1105 }
1106
1107 LayoutInfo getLayoutInfo(Value val);
1108
1109 void printAnalysisResult(llvm::raw_ostream &os);
1110
1111private:
1112 DataFlowSolver solver;
1113 const Operation *target;
1114};
1115} // namespace
1116
1117LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
1118 auto *state = solver.lookupState<LayoutInfoLattice>(val);
1119 if (!state)
1120 return {};
1121 return state->getValue();
1122}
1123
1124// Print the analysis result for debugging purposes.
1125void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
1126 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
1127 os << "function: " << funcOp.getName() << ":\n";
1128 // Function arguments
1129 for (BlockArgument arg : funcOp.getArguments()) {
1130 LayoutInfo layout = getLayoutInfo(arg);
1131 os << "argument: " << arg << "\n";
1132 os << "layout : ";
1133 layout.print(os);
1134 os << "\n";
1135 }
1136 // Function ops
1137 funcOp.walk([&](Operation *op) {
1138 // Skip ops that do not have results
1139 if (op->getResults().empty())
1140 return;
1141 os << "op : ";
1142 // For control-flow ops, print the op name only.
1143 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
1144 os << op->getName();
1145 else
1146 op->print(os);
1147 os << "\n";
1148 // Print the layout for each result.
1149 for (auto [i, r] : llvm::enumerate(op->getResults())) {
1150 LayoutInfo layout = getLayoutInfo(r);
1151 os << "layout for result #" << i << ": ";
1152 layout.print(os);
1153 os << "\n";
1154 }
1155 });
1156 };
1157
1158 SmallVector<FunctionOpInterface> funcOps;
1159 if (auto modOp = dyn_cast<ModuleOp>(target)) {
1160 for (auto funcOp : modOp.getOps<FunctionOpInterface>())
1161 funcOps.push_back(funcOp);
1162
1163 // Collect all GpuFuncOps in the module.
1164 for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
1165 for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
1166 funcOps.push_back(gpuFuncOp);
1167 }
1168 }
1169 // Print the analysis result for each function.
1170 for (FunctionOpInterface funcOp : funcOps)
1171 printFunctionResult(funcOp);
1172}
1173
1174using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
1175/// Update an operation with the layout of its results. If the result type is
1176/// a vector type, a temporary layout attribute is added to the operation. If
1177/// the result type is a tensor descriptor type, the type is updated with the
1178/// layout attribute. The users of the result are also updated with the layout
1179/// attribute.
1180static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
1181 GetLayoutFnTy getLayoutOfValue) {
1182 // Region ops (like scf.for) are already handled by the
1183 // updateControlFlowOps.
1184 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
1185 return success();
1186
1187 // Iterate over all the results.
1188 for (OpResult result : op->getResults()) {
1189 Type resultType = result.getType();
1190 // Layouts are needed only for vector and tensor descriptor types.
1191 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
1192 continue;
1193 // If the result has no layout but has users, emit a warning and continue.
1194 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
1195 if (!layout && result.getNumUses() > 0) {
1196 op->emitWarning("op has users but no layout assigned for its result");
1197 continue;
1198 }
1199 // If the result is a tensor descriptor type, update the tensor desc type
1200 // with layout.
1201 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
1202 auto typeWithLayout = xegpu::TensorDescType::get(
1203 tensorDescTy.getContext(), tensorDescTy.getShape(),
1204 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1205 result.setType(typeWithLayout);
1206 continue;
1207 }
1208 // If the result is a vector type, add a temporary layout attribute to the
1209 // op.
1210 xegpu::setDistributeLayoutAttr(result, layout);
1211 }
1212 return success();
1213}
1214
1215/// Region ops like scf.for need special handling because they have blocks
1216/// inside. If the blocks have tensor descriptor type as block arguments,
1217/// their types must be updated. Also, a region op can have results that may not
1218/// have any users (e.g. A and B tiles). They are not assigned a layout by
1219/// layout analysis because they have no users. However inside the region op
1220/// corresponding block arguments for these results do have layouts.
1221/// Therefore, in this case we still need to update the result types with the
1222/// layout attribute. This function updates the internal block
1223/// arguments and the result types of the region op with the assigned layouts.
1224/// clang-format off
1225/// Example: scf.for ... iter_args(...) -> (out types) {
1226/// ^bb0(block types):
1227/// ...
1228/// scf.yield ... : (yield types)
1229/// }
1230/// clang-format on
1231/// In this example, at scf.yield, control-flow can transfer to two successor
1232/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
1233/// itself (yield the results). So we update both the block arguments of the
1234/// successor region (i.e. block types) and the result types of the scf.for op
1235/// (i.e. out types). Note that yield types are updated by respective
1236/// producers inside bb0.
1237static LogicalResult
1238updateControlFlowOps(mlir::OpBuilder &builder,
1239 mlir::RegionBranchTerminatorOpInterface terminator,
1240 GetLayoutFnTy getLayoutOfValue) {
1241 // Only process if the terminator is inside a region branch op.
1242 auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
1243 if (!branchOp)
1244 return success();
1245
1246 RegionBranchSuccessorMapping mapping;
1247 branchOp.getSuccessorOperandInputMapping(mapping,
1248 RegionBranchPoint(terminator));
1249 for (const auto &[successorOperand, successorInputs] : mapping) {
1250 for (Value successorInput : successorInputs) {
1251 Type inputType = successorInput.getType();
1252 // We only need to operate on tensor descriptor or vector types.
1253 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
1254 continue;
1255 xegpu::DistributeLayoutAttr successorInputLayout =
1256 getLayoutOfValue(successorInput);
1257 xegpu::DistributeLayoutAttr successorOperandLayout =
1258 getLayoutOfValue(successorOperand->get());
1259
1260 // If the forwarded operand's layout is not assigned, we cannot proceed.
1261 if (!successorOperandLayout) {
1262 LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
1263 "branch terminator: "
1264 << successorOperand->get() << "\n");
1265 return failure();
1266 }
1267 // We expect the layouts to match.
1268 if (successorInputLayout &&
1269 successorInputLayout != successorOperandLayout) {
1270 LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
1271 "operand forwarded as the argument: "
1272 << successorInputLayout << " vs "
1273 << successorOperandLayout << "\n");
1274 return failure();
1275 }
1276 // Get tensor descriptor type with the layout.
1277 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
1278 auto newTdescTy = xegpu::TensorDescType::get(
1279 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
1280 tdescTy.getEncoding(), successorOperandLayout);
1281 successorInput.setType(newTdescTy);
1282 continue;
1283 }
1284 // If the type is a vector type and this region argument is an OpResult,
1285 // set the layout attribute on the OpResult.
1286 if (auto result = dyn_cast<OpResult>(successorInput))
1287 xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
1288 }
1289 }
1290 return success();
1291}
1292
1293/// Update the function arguments and results with the layouts.
1294static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1295 mlir::FunctionOpInterface funcOp,
1296 GetLayoutFnTy getLayoutOfValue) {
1297 SmallVector<Type> newArgTypes;
1298 // Update the function arguments.
1299 for (BlockArgument arg : funcOp.getArguments()) {
1300 Type argType = arg.getType();
1301 newArgTypes.push_back(argType);
1302 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1303 continue;
1304 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1305 if (!layout) {
1306 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1307 << " but got none.\n");
1308 return failure();
1309 }
1310 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1311 auto newTdescTy = xegpu::TensorDescType::get(
1312 tensorDescTy.getContext(), tensorDescTy.getShape(),
1313 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1314 arg.setType(newTdescTy);
1315 newArgTypes.back() = newTdescTy;
1316 }
1317 }
1318 // Update the function type with the new argument types.
1319 // NOTE: We assume that function results are not expected to have layouts.
1320 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1321 funcOp.getResultTypes()));
1322 return success();
1323}
1324
1325namespace {
1326struct XeGPUPropagateLayoutPass final
1327 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1328 XeGPUPropagateLayoutPass() = default;
1329 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
1330 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1331 : XeGPUPropagateLayoutBase(options) {}
1332 void runOnOperation() override;
1333};
1334
1335} // namespace
1336
1337void XeGPUPropagateLayoutPass::runOnOperation() {
1338 LayoutKind layoutKind;
1339 if (this->layoutKind == "lane") {
1340 layoutKind = LayoutKind::Lane;
1341 } else if (this->layoutKind == "inst") {
1342 layoutKind = LayoutKind::InstData;
1343 } else if (this->layoutKind == "subgroup") {
1344 layoutKind = LayoutKind::Subgroup;
1345 } else {
1346 getOperation()->emitError("Unsupported layout kind option: " +
1347 this->layoutKind);
1348 signalPassFailure();
1349 return;
1350 }
1351 RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
1352 // Print the analysis result and exit. (for debugging purposes)
1353 if (printOnly) {
1354 auto &os = llvm::outs();
1355 analysis.printAnalysisResult(os);
1356 return;
1357 }
1358 // Helper to convert LayoutInfo to xegpu::DistributeLayoutAttr.
1359 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1360 LayoutInfo layout = analysis.getLayoutInfo(val);
1361 if (!layout.isAssigned())
1362 return {};
1363 xegpu::DistributeLayoutAttr layoutAttr =
1364 cast<xegpu::DistributeLayoutAttr>(layout.get());
1365 if (layout.isSliceLayout())
1366 return cast<xegpu::SliceAttr>(layoutAttr);
1367 return cast<xegpu::LayoutAttr>(layoutAttr);
1368 };
1369
1370 mlir::OpBuilder builder(&getContext());
1371 Operation *op = getOperation();
1372 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1373 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1374 LogicalResult r = success();
1375 TypeSwitch<Operation *>(&op)
1376 .Case<mlir::RegionBranchTerminatorOpInterface>(
1377 [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1378 r = updateControlFlowOps(builder, branchTermOp,
1379 getXeGPULayoutForValue);
1380 })
1381 .Case<mlir::FunctionOpInterface>(
1382 [&](mlir::FunctionOpInterface funcOp) {
1383 r = updateFunctionOpInterface(builder, funcOp,
1384 getXeGPULayoutForValue);
1385 })
1386 .Default([&](Operation *op) {
1387 r = updateOp(builder, op, getXeGPULayoutForValue);
1388 });
1389 if (failed(r)) {
1390 op.emitError("Failed to update operation with the layout.");
1391 return WalkResult::interrupt();
1392 }
1393 }
1394 return WalkResult::advance();
1395 });
1396 if (walkResult.wasInterrupted()) {
1397 signalPassFailure();
1398 return;
1399 }
1400}