//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;

namespace {

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//

/// Helper class for tracking the analysis state of an MLIR value. For layout
/// propagation, the analysis state is simply the distribution layout of each
/// value. The distribution layout information is encapsulated using the
/// xegpu::DistributeLayoutAttr class, which can hold information about any
/// type of distribution layout that the XeGPU dialect supports. The purpose
/// of this analysis is to propagate a unique distribution layout to each
/// value in the program, starting from a set of anchor operations (like DPAS,
/// StoreNd, etc.). Note that the analysis reaches a fixed point once every
/// value has been assigned some layout; the analysis never modifies an
/// already assigned layout.
///
/// Given this, LayoutInfo satisfies the following properties:
/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
///    assigned`.
/// 2) Two LayoutInfo values are equal if they are both assigned or both not
///    assigned. The concrete value of the assigned state does not matter.
/// 3) The meet operator works as follows:
///    - If the current state is assigned, return the current state (a unique
///      layout is already assigned; don't change it).
///    - Otherwise, return the other state.
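///
/// For example (following the meet definition below): meeting an unassigned
/// LayoutInfo with an assigned one adopts the assigned layout, while meeting
/// two assigned LayoutInfos keeps the first:
///   meet(not-assigned, layoutB) == layoutB
///   meet(layoutA, layoutB) == layoutA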

struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const { return storage != nullptr; }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  SmallVector<int> getLaneLayout() const;

  SmallVector<int> getLaneData() const;

  SmallVector<int> getInstData() const;

  SmallVector<int> getSgLayout() const;

  SmallVector<int> getSgData() const;

  SmallVector<int> getOrder() const;

  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const {
    if (!isAssigned())
      return -1;
    return storage.getRank();
  }

  Attribute get() { return storage; }
  void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
};

SmallVector<int> LayoutInfo::getLaneLayout() const {
  if (!isAssigned())
    return {};
  return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getLaneData() const {
  if (!isAssigned())
    return {};
  return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getInstData() const {
  if (!isAssigned())
    return {};
  return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getSgLayout() const {
  if (!isAssigned())
    return {};
  return llvm::map_to_vector(storage.getEffectiveSgLayoutAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getSgData() const {
  if (!isAssigned())
    return {};
  return llvm::map_to_vector(storage.getEffectiveSgDataAsInt(),
                             [](int64_t val) { return static_cast<int>(val); });
}

SmallVector<int> LayoutInfo::getOrder() const {
  if (!isAssigned() || !storage.getOrder())
    return {};
  return llvm::map_to_vector(storage.getOrder().asArrayRef(),
                             [](int64_t val) { return static_cast<int>(val); });
}

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << storage;
  } else {
    os << "Not assigned.";
  }
}

LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

/// Since this is a backward analysis, the join method is not used.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}

/// Construct a new layout with the lane_layout/lane_data, inst_data,
/// sg_layout/sg_data, and order transposed according to the permutation.
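/// For instance, applying the permutation [1, 0] to a layout with
/// lane_layout = [1, 16] and lane_data = [1, 1] yields a layout with
/// lane_layout = [16, 1] and lane_data = [1, 1].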
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  // Check if the permutation is valid.
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return {};
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  SmallVector<int32_t> instData;
  SmallVector<int32_t> sgLayout;
  SmallVector<int32_t> sgData;
  SmallVector<int32_t> order;

  for (int64_t idx : permutation) {
    if (getLaneLayout().size()) {
      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
    }
    if (getInstData().size())
      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
    if (getSgData().size()) {
      sgLayout.push_back(static_cast<int32_t>(getSgLayout()[idx]));
      sgData.push_back(static_cast<int32_t>(getSgData()[idx]));
    }
    if (getOrder().size()) {
      order.push_back(static_cast<int32_t>(getOrder()[idx]));
    }
  }
  auto orderAttr = order.size()
                       ? DenseI32ArrayAttr::get(storage.getContext(), order)
                       : nullptr;
  xegpu::LayoutAttr layoutAttr;
  if (getLaneLayout().size())
    layoutAttr =
        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
  if (getInstData().size())
    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
  if (getSgData().size())
    layoutAttr = xegpu::LayoutAttr::get(
        storage.getContext(),
        DenseI32ArrayAttr::get(storage.getContext(), sgLayout),
        DenseI32ArrayAttr::get(storage.getContext(), sgData),
        /*inst_data=*/nullptr, /*lane_layout=*/nullptr,
        /*lane_data=*/nullptr, orderAttr);
  return LayoutInfo(layoutAttr);
}

//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
  using Lattice::Lattice;
};

/// Helper functions to get default layouts. A `default layout` is a layout
/// that is assigned to a value when the layout is not fixed by some anchor
/// operation (like DPAS).

/// Helper function to get the default layout for uniform values like
/// constants. For a 1D vector, lane_layout is [subgroupSize] and lane_data is
/// [1]. For a 2D vector, lane_layout is [1, subgroupSize] and lane_data is
/// [1, 1].
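/// For example, assuming a subgroup size of 16, this produces
/// #xegpu.layout<lane_layout = [16], lane_data = [1]> in the 1D case and
/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> in the 2D case.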
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank,
                                           const xegpu::uArch::uArch *uArch) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1) {
    return LayoutInfo(
        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
  }
  return LayoutInfo(
      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}

static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank, int subgroupSize) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1) {
    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
  }
  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}

/// Helper to get the default layout for 2D block operations.
template <typename Ty>
static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
                                           const xegpu::uArch::uArch *uArch,
                                           unsigned packingSize) {
  // Expecting a 1D or 2D vector.
  assert((ty.getRank() == 1 || ty.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting an int or float element type.
  assert(ty.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D vector.
  if (ty.getRank() == 1)
    return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
  // The packing factor is determined by the element type bitwidth.
  unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
  return LayoutInfo(xegpu::LayoutAttr::get(
      ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
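
// Illustrative example for getSIMTLayoutInfoBlockIO: for vector<16x16xf16>
// with packingSize = 32 and assuming a subgroup size of 16, the element
// bitwidth (16) is less than packingSize, so packingFactor = 32 / 16 = 2 and
// the result is #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>.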

//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for the operands of the
/// DPAS, StoreNd, and StoreScatter ops are fixed (known before propagation).
/// The purpose of this analysis is to propagate those known layouts to all of
/// their producers and (other) consumers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  xegpu::LayoutKind layoutKind;

  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);

  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void
  visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
                          ArrayRef<LayoutInfoLattice *> operands,
                          ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadMatrixOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreMatrixOp store,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable,
                        xegpu::LayoutKind layoutKind)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        layoutKind(layoutKind) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void
  visitNonControlFlowArguments(RegionSuccessor &successor,
                               ArrayRef<BlockArgument> arguments) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
} // namespace

LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case(
          [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case([&](xegpu::StoreNdOp storeNdOp) {
        visitStoreNdOp(storeNdOp, operands, results);
      })
      .Case([&](xegpu::StoreScatterOp storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case([&](xegpu::LoadNdOp loadNdOp) {
        visitLoadNdOp(loadNdOp, operands, results);
      })
      .Case([&](xegpu::LoadGatherOp loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case([&](xegpu::CreateDescOp createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case([&](xegpu::UpdateNdOffsetOp updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case([&](xegpu::PrefetchNdOp prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case([&](vector::TransposeOp transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case([&](vector::BitCastOp bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case([&](vector::MultiDimReductionOp reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case([&](vector::BroadcastOp broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case([&](vector::ShapeCastOp shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      .Case([&](vector::InsertStridedSliceOp insertStridedSliceOp) {
        visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
      })
      .Case([&](xegpu::LoadMatrixOp loadMatrixOp) {
        visitLoadMatrixOp(loadMatrixOp, operands, results);
      })
      .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
        visitStoreMatrixOp(storeMatrixOp, operands, results);
      })
      // All other ops.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // If the operand type is not a vector or tensor descriptor, skip
            // it.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            // Propagate the result layout to the operand.
            meet(operandInfo, *resultInfo);
          }
        }
      });

  return success();
}

bool LayoutInfoPropagation::hasParamsOfLayoutKind(
    xegpu::DistributeLayoutAttr anchorLayout) {
  if (anchorLayout == nullptr) {
    return false;
  }
  if (layoutKind == xegpu::LayoutKind::InstData) {
    return !(anchorLayout.getEffectiveInstDataAsInt().empty());
  } else if (layoutKind == xegpu::LayoutKind::Lane) {
    return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
             anchorLayout.getEffectiveLaneDataAsInt().empty());
  } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
    return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
             anchorLayout.getEffectiveSgDataAsInt().empty());
  }
  return false;
}
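
// For example, when layoutKind is Lane, an anchor layout that only carries
// sg_layout/sg_data (but no lane_layout/lane_data) does not count as having
// params of the requested kind, so the default layout assignment in the
// visitors below kicks in.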

// This function returns all layouts for the given sgCount whose sgData:
// 1. Evenly divides the wgShape.
// 2. Is a multiple of instData.
// Example:
// wgShape = [128, 64], instData = [8, 16], sgCount = 32
// Returns layouts:
// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
SmallVector<std::pair<int, int>> getValidLayouts(ArrayRef<int64_t> wgShape,
                                                 ArrayRef<int> instData,
                                                 int64_t sgCount) {
  SmallVector<std::pair<int, int>> candidates;
  for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
    if (sgCount % sgLayout0)
      continue;
    int sgLayout1 = sgCount / sgLayout0;
    int sgData0 = wgShape[0] / sgLayout0;
    int sgData1 = wgShape[1] / sgLayout1;
    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
        (sgData0 % instData[0] || sgData1 % instData[1]))
      continue;
    candidates.emplace_back(sgLayout0, sgLayout1);
  }
  // Sort primarily by how balanced the two dimensions are
  // (i.e., minimize the absolute difference between them), and
  // secondarily by the first dimension in ascending order.
  llvm::sort(candidates, [](const std::pair<int, int> &lhs,
                            const std::pair<int, int> &rhs) {
    int diffLhs = std::abs(lhs.first - lhs.second);
    int diffRhs = std::abs(rhs.first - rhs.second);
    if (diffLhs != diffRhs)
      return diffLhs < diffRhs;
    return lhs.first < rhs.first;
  });
  return candidates;
}
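
// Worked trace of the example above (wgShape = [128, 64], instData = [8, 16],
// sgCount = 32): the divisor pairs of 32 are (1,32), (2,16), (4,8), (8,4),
// (16,2), and (32,1). Only (8,4) and (16,2) survive the divisibility checks
// (e.g., (4,8) gives sgData1 = 64/8 = 8, which is not a multiple of
// instData[1] = 16). Sorting by balance puts (8,4) (diff 4) before (16,2)
// (diff 14).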

FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
  // The shape of the workitem layout is irrelevant here; only the total count
  // matters.
  auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
  if (!gpuFunc)
    return failure();
  auto knownBlockSize = gpuFunc.getKnownBlockSize();
  if (!knownBlockSize.has_value())
    return failure();
  const int flatBlockSize = llvm::product_of(knownBlockSize.value());
  return flatBlockSize / sgSize;
}
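
// For example, a gpu.func with a known block size of [128, 1, 1] and a
// subgroup size of 16 yields 128 / 16 = 8 subgroups.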
563
564void LayoutInfoPropagation::visitPrefetchNdOp(
565 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
566 ArrayRef<const LayoutInfoLattice *> results) {
567
568 LayoutInfo prefetchLayout;
569 xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
570 if (hasParamsOfLayoutKind(anchorLayout)) {
571 prefetchLayout = LayoutInfo(anchorLayout);
572 } else {
573 // Here we assign the default layout to the tensor descriptor operand of
574 // prefetch.
575 auto tdescTy = prefetch.getTensorDescType();
576
577 auto uArch = getUArch(getChipStr(prefetch).value_or(""));
578 const auto *uArchInstruction =
579 dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
580 uArch->getInstruction(
581 xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
582
583 auto blockWHC =
584 uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
585 if (!blockWHC)
586 prefetch.emitWarning("No known block params found for the element type.");
587 auto [bWidth, bHeight, bCount] = blockWHC.value();
588 SmallVector<int> instData;
589 int instWidth = xegpu::getLargestDivisor(
590 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
591 if (instWidth == -1)
592 prefetch.emitWarning(
593 "No suitable instruction multiple found for the given shape.");
594 if (tdescTy.getRank() == 1)
595 instData = {instWidth};
596 else {
597 int instHeight = xegpu::getLargestDivisor(
598 static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
599 if (instHeight == -1)
600 prefetch.emitWarning(
601 "No suitable instruction multiple found for the given shape.");
602 instData = {instHeight, instWidth};
603 }
604
605 if (layoutKind == xegpu::LayoutKind::InstData)
606 prefetchLayout =
607 LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
608 else
609 prefetchLayout = getSIMTLayoutInfoBlockIO(
610 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
611
612 prefetch.setLayoutAttr(
613 dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
614 }
615 // Propagate the layout to the source tensor descriptor.
616 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
617}
618
619void LayoutInfoPropagation::visitVectorMultiReductionOp(
620 vector::MultiDimReductionOp reduction,
621 ArrayRef<LayoutInfoLattice *> operands,
622 ArrayRef<const LayoutInfoLattice *> results) {
623 // The layout of the result must be present.
624 LayoutInfo resLayoutInfo = results[0]->getValue();
625 if (!resLayoutInfo.isAssigned())
626 return;
627
628 VectorType sourceTy = reduction.getSourceVectorType();
629 SmallVector<int64_t> reductionDims(reduction.getReductionDims());
630
631 auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
632 auto consumerLayoutAttr =
633 dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
634
635 // The result layout represents the layout requirements of the operation.
636 // it is recorded to anchor layout or temporary layout.
637 // it must be honored for current op and may conflict with the layout
638 // propagated from consumer op, the conflict is resolved in later phase by
639 // converting the required result layout to the consumer layout
640 auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
641 layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
642
643 xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
644
645 // derive the source layout from the dominant layout and reduction dims
646 auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
647 requiredResLayoutAttr, reductionDims);
648
649 propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
650 // Accumulator should have the same layout as the result.
651 propagateIfChanged(operands[1],
652 operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
653}

void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;

  // Only consider vector-to-vector broadcasts for now.
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  // Skip layout propagation for a non-vector source operand.
  if (!sourceTy)
    return;

  auto srcShape = sourceTy.getShape();
  auto resShape = resultTy.getShape();

  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
  for (size_t i = 0; i < srcShape.size(); i++)
    if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
      broadcast.emitWarning("broadcast must be either from a lower rank or "
                            "from the same rank with unit dims; the mixed "
                            "scenario is not supported!");

  auto resultLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

  xegpu::DistributeLayoutAttr srcLayoutAttr =
      xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);

  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
  return;
}

void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
  ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
  auto resultLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

  xegpu::DistributeLayoutAttr srcLayoutAttr =
      xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);

  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}

/// Propagate the layout of the result tensor to the source tensor descriptor
/// in UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo dpasALayout;
  LayoutInfo dpasBLayout;
  LayoutInfo dpasCDLayout;

  xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr();
  if (hasParamsOfLayoutKind(anchorLayoutCD)) {
    xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr();
    xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr();
    assert(hasParamsOfLayoutKind(anchorLayoutA) &&
           "Expected anchor layout for DPAS A operand.");
    assert(hasParamsOfLayoutKind(anchorLayoutB) &&
           "Expected anchor layout for DPAS B operand.");
    dpasALayout = LayoutInfo(anchorLayoutA);
    dpasBLayout = LayoutInfo(anchorLayoutB);
    dpasCDLayout = LayoutInfo(anchorLayoutCD);
  } else {
    auto uArch = getUArch(getChipStr(dpas).value_or(""));
    VectorType aTy = dpas.getLhsType();
    VectorType bTy = dpas.getRhsType();
    VectorType cdTy = dpas.getResultType();

    xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
        requiredBLayout;

    int numSg = 0;
    if (layoutKind == xegpu::LayoutKind::Subgroup) {
      LayoutInfo consumerLayout = results[0]->getValue();
      if (!consumerLayout.isAssigned())
        return;
      consumerLayoutAttr =
          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
      auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
      if (failed(numSgOrErr)) {
        dpas.emitWarning(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      numSg = numSgOrErr.value();
    }
    auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
                                          consumerLayoutAttr, uArch, numSg);
    if (!layouts.has_value()) {
      dpas.emitWarning(
          "Failed to determine required layouts for DPAS operands.");
      return;
    }

    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;

    dpas.setLayoutAAttr(requiredALayout);
    dpas.setLayoutBAttr(requiredBLayout);
    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
    dpasALayout = LayoutInfo(requiredALayout);
    dpasBLayout = LayoutInfo(requiredBLayout);
    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
  }
  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
  if (operands.size() > 2)
    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
}

/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout;
  xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    storeLayout = LayoutInfo(anchorLayout);
  } else {
    auto uArch = getUArch(getChipStr(store).value_or(""));
    const auto *uArchInstruction =
        dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
    VectorType dataTy = store.getValueType();
    auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
        store.getValueType().getElementType());
    if (!blockWHC)
      store.emitWarning("No known block params found for the element type.");
    auto [bWidth, bHeight, bCount] = blockWHC.value();
    SmallVector<int> instData;
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      store.emitWarning(
          "No suitable instruction multiple found for the given shape.");
    if (dataTy.getRank() == 1)
      instData = {instWidth};
    else {
      int instHeight = xegpu::getLargestDivisor(
          static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
      if (instHeight == -1)
        store.emitWarning(
            "No suitable instruction multiple found for the given shape.");
      instData = {instHeight, instWidth};
    }

    if (layoutKind == xegpu::LayoutKind::InstData)
      storeLayout =
          LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
    else if (layoutKind == xegpu::LayoutKind::Lane)
      storeLayout =
          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
                                   uArchInstruction->getPackedFormatBitSize());
    else { // xegpu::LayoutKind::Subgroup
      auto sgSize = uArch->getSubgroupSize();
      auto numSgOrErr = getNumSg(store, sgSize);
      if (failed(numSgOrErr)) {
        store.emitWarning(
            "Unable to determine the number of subgroups for the operation.");
        return;
      }
      auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
                                       instData, numSgOrErr.value());
      if (sgLayouts.empty()) {
        store.emitWarning(
            "Unable to determine suitable subgroup layout for store value.");
        return;
      }
      SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
      SmallVector<int> sgData = {
          static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
          static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
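      // Illustrative example: for a 128x64 store value with instData = [8, 16]
      // and 32 subgroups, getValidLayouts returns [(8,4), (16,2)]; taking the
      // first candidate gives sg_layout = [8, 4] and
      // sg_data = [128/8, 64/4] = [16, 16].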
      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
          dataTy.getContext(),
          DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
          DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
          /*inst_data=*/nullptr, /*lane_layout=*/nullptr,
          /*lane_data=*/nullptr, /*order=*/nullptr));
    }
    store.setLayoutAttr(
        dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
  }
  // Propagate the layout to the value operand.
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}

/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo loadLayout;
  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
  if (hasParamsOfLayoutKind(anchorLayout)) {
    loadLayout = LayoutInfo(anchorLayout);
  } else {
    LayoutInfo valueLayout = results[0]->getValue();
    // Need the layout of the value to propagate to the tensor descriptor.
    if (!valueLayout.isAssigned())
      return;
    loadLayout = valueLayout;
    // LoadNdOp has the transpose effect. However, at the stage of this
    // analysis this effect is not expected and should be abstracted away.
    // Emit a warning.
    if (auto transpose = load.getTranspose()) {
      load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                       "LayoutInfoPropagation stage.");
      loadLayout = valueLayout.transpose(transpose.value());
    }
    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
}

/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}

/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit widths of the source and result types.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of the bitcast result to propagate to the operands.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;

  auto srcVecType = bitcast.getSourceVectorType();
  auto resVecType = bitcast.getResultVectorType();

  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
  auto requiredResLayoutAttr = setupBitCastResultLayout(
      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);

  xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);

  int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();

  // Derive the source layout from the required result layout and the
  // source/result element bit widths.
  auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
      requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);

  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
}

void LayoutInfoPropagation::visitInsertStridedSliceOp(
    vector::InsertStridedSliceOp insertStridedSlice,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;

  auto srcVecType = insertStridedSlice.getSourceVectorType();
  auto resVecType = insertStridedSlice.getDestVectorType();

  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
  auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));

  auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);

  xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
                            requiredResLayoutAttr);

  auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
      requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());

  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
  propagateIfChanged(operands[1],
                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
  return;
}

/// Propagate the layout of the result to the tensor descriptor, mask, and
/// offset operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
  xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
  auto uArch = getUArch(getChipStr(load).value_or(""));
  auto subgroupSize = uArch->getSubgroupSize();
  VectorType resVecTy = load.getValueType();
  int chunkSize = load.getChunkSize().value_or(1);

  LayoutInfo resLayoutInfo = results[0]->getValue();
  if (!resLayoutInfo.isAssigned())
    return;
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
    requiredAnchorLayoutAttr = anchorLayoutAttr;
  } else {
    if (!resVecTy) {
      load.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
        layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
    load.setLayoutAttr(requiredAnchorLayoutAttr);
  }

  auto maskLayoutAttr = requiredAnchorLayoutAttr;
  // Special handling of the mask layout for chunked ops: enforce the default
  // XeGPU 1D layout for the mask.
  if (chunkSize > 1) {
    if (layoutKind == xegpu::LayoutKind::InstData)
      maskLayoutAttr =
          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
    else if (layoutKind == xegpu::LayoutKind::Lane)
      maskLayoutAttr =
          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
    else
      assert(false &&
             "chunked LoadGatherOp should not be used at workgroup level");
  }

  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
  auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);

  // Propagate the new layout to the tensor descriptor operand.
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
  // Propagate the new layout to the mask and optional offset operands.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
}

/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  auto uArch = getUArch(getChipStr(createDesc).value_or(""));
  // For the offset operand, propagate the 1D default layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                               uArch->getSubgroupSize());
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, offset, and mask operands
/// in StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
  xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
  auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
  auto subgroupSize = uArch->getSubgroupSize();
  VectorType srcVecTy = storeScatter.getValueType();
  int chunkSize = storeScatter.getChunkSize().value_or(1);

  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
    requiredAnchorLayoutAttr = anchorLayoutAttr;
  } else {
    if (!srcVecTy) {
      storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
      return;
    }
    requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
        layoutKind, srcVecTy, chunkSize, uArch);
    storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
  }

  LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
  auto maskLayoutAttr = requiredAnchorLayoutAttr;
  // Special handling of the mask layout for chunked ops: enforce the default
  // XeGPU 1D layout for the mask.
  if (chunkSize > 1) {
    if (layoutKind == xegpu::LayoutKind::InstData)
      maskLayoutAttr =
          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
    else if (layoutKind == xegpu::LayoutKind::Lane)
      maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
                                              {subgroupSize}, {1});
    else
      assert(false &&
             "chunked StoreScatterOp should not be used at workgroup level");
  }

  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);

  // Propagate the payload operand layout.
  propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
  // Propagate the destination (if tdesc) operand layout.
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
  // Propagate the new layout to the mask and optional offset operands.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
}

void LayoutInfoPropagation::visitLoadMatrixOp(
    xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {

  LayoutInfo resLayoutInfo = results[0]->getValue();
  auto consumerLayoutAttr =
      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());

  xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();

  // Only need to set the anchor layout; no need to propagate to the memdesc
  // and offset operands.
  if (!hasParamsOfLayoutKind(anchorLayout)) {
    VectorType resVecTy =
        llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
    assert(resVecTy.getRank() == 2 && "Expecting 2D vector for load matrix.");
    auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
    auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
        layoutKind, resVecTy, consumerLayoutAttr, uArch);
    loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
  }
}

// Store matrix is a flavor of scattered store for 2D shapes.
void LayoutInfoPropagation::visitStoreMatrixOp(
    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
  LayoutInfo layout;
  if (hasParamsOfLayoutKind(anchorLayout)) {
    layout = LayoutInfo(anchorLayout);
  } else {
    VectorType srcVecTy =
        llvm::cast<VectorType>(storeMatrix.getData().getType());
    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
    auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
    auto requiredAnchorLayoutAttr =
        xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
    layout = LayoutInfo(requiredAnchorLayoutAttr);
  }

  propagateIfChanged(operands[0], operands[0]->meet(layout));
}

namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
      : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
} // namespace

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}

// Print the analysis result for debugging purposes.
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout : ";
      layout.print(os);
      os << "\n";
    }
    // Function ops.
    funcOp.walk([&](Operation *op) {
      // Skip ops that do not have results.
      if (op->getResults().empty())
        return;
      os << "op : ";
      // For control-flow ops, print the op name only.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        os << op->getName();
      else
        op->print(os);
      os << "\n";
      // Print the layout for each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);

    // Collect all GpuFuncOps in the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  // Print the analysis result for each function.
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}

namespace {

//===----------------------------------------------------------------------===//
// ResolveLayoutConflicts
//===----------------------------------------------------------------------===//
struct ResolveLayoutConflicts {
  ResolveLayoutConflicts(Operation *parentOp)
      : parentOp(parentOp), builder(parentOp->getContext()) {}
  LogicalResult run();

private:
  Operation *parentOp;
  OpBuilder builder;
  LogicalResult resolveTensorDescConsumer(OpOperand &operand);
  LogicalResult resolveVectorConsumer(OpOperand &operand);
};

} // namespace

LogicalResult ResolveLayoutConflicts::run() {
  // Scan all operations in the parent op and resolve layout conflicts at
  // tensor descriptor and vector use points.
  auto r = parentOp->walk([&](Operation *op) -> WalkResult {
    for (OpOperand &operand : op->getOpOperands()) {
      // Handle conflicts in tensor descriptor operands.
      Type operandType = operand.get().getType();
      if (isa<xegpu::AnchorLayoutInterface>(op) &&
          isa<xegpu::TensorDescType>(operandType)) {
        auto res = resolveTensorDescConsumer(operand);
        return succeeded(res) ? WalkResult::advance() : WalkResult::interrupt();
      }
      // Handle conflicts in vector operands.
      if (isa<VectorType>(operandType)) {
        auto res = resolveVectorConsumer(operand);
        return succeeded(res) ? WalkResult::advance() : WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return r.wasInterrupted() ? failure() : success();
}

/// Helper to get the defining CreateNdDescOp of a tensor descriptor value.
/// This function tries to find the defining CreateNdDescOp recursively across
/// control-flow boundaries.
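/// For example, given IR like the following sketch:
///   %desc = xegpu.create_nd_tdesc %src : memref<...> -> !xegpu.tensor_desc<...>
///   %res = scf.for ... iter_args(%arg = %desc) -> (...) { ... }
/// a lookup on the block argument %arg follows the tied loop init back to
/// %desc and returns the defining xegpu.create_nd_tdesc.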
static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
  // Try to get the defining CreateNdDescOp of the tensor descriptor.
  auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
  if (definingOp)
    return definingOp;
  // If tdescValue is an argument, try to get the tied init value from the
  // parent loop-like op.
  if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
    auto *parentOp = arg.getOwner()->getParentOp();
    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
      if (tiedInit)
        return getDefiningCreateNdDescOp(tiedInit->get());
    }
  }
  // If not found, return null.
  return nullptr;
}

LogicalResult
ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
  // TODO: Implement vector consumer layout conflict resolution. Requires
  // layout utilities.
  return success();
}

LogicalResult
ResolveLayoutConflicts::resolveTensorDescConsumer(OpOperand &operand) {
  Operation *consumerOp = operand.getOwner();
  Value tdescValue = operand.get();
  auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(consumerOp);
  auto currTDescType = dyn_cast<xegpu::TensorDescType>(tdescValue.getType());
  assert(anchorOp && currTDescType &&
         "Expected anchor layout op and tensor descriptor consumer.");
  // TODO: Scattered tensor desc is not supported for now.
  if (currTDescType.isScattered()) {
    DBGS() << "Scattered tensor descriptor not supported: " << tdescValue
           << "\n";
    return failure();
  }
  Attribute currLayout = currTDescType.getLayout();
  Attribute expectedLayout = anchorOp.getAnchorLayout();
  // A conflict exists in a tensor descriptor operand if the tensor
  // descriptor's layout is different from the anchor layout expected by the
  // consumer.
  if (expectedLayout && currLayout && expectedLayout != currLayout) {
    // Try to get the defining CreateNdDescOp of the tensor descriptor.
    auto conflictingCreateNdOp = getDefiningCreateNdDescOp(tdescValue);
    if (!conflictingCreateNdOp) {
      DBGS() << "Unable to find defining CreateNdDescOp for tensor descriptor: "
             << tdescValue << "\n";
      return failure();
    }
    // Duplicate the CreateNdDescOp with the expected layout.
    builder.setInsertionPointAfter(conflictingCreateNdOp);
    auto newTensorDescType = xegpu::TensorDescType::get(
        conflictingCreateNdOp.getContext(), currTDescType.getShape(),
        currTDescType.getElementType(), currTDescType.getEncoding(),
        expectedLayout);
    xegpu::CreateNdDescOp newOp = xegpu::CreateNdDescOp::create(
        builder, consumerOp->getLoc(), newTensorDescType,
        conflictingCreateNdOp->getOperands(),
        conflictingCreateNdOp->getAttrs());
    // Replace the tensor descriptor operand in the consumer op with the new
    // tensor descriptor.
    consumerOp->replaceUsesOfWith(tdescValue, newOp.getResult());
  }
  return success();
}
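
// Sketch of the resolution above: if %desc was created with layout #a but the
// consuming anchor op expects layout #b, a clone of the defining
// xegpu.create_nd_tdesc producing !xegpu.tensor_desc<..., #b> is inserted
// right after the original, and the consumer is rewired to use the clone.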

using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is
/// a vector type, a temporary layout attribute is added to the operation. If
/// the result type is a tensor descriptor type, the type is updated with the
/// layout attribute. The users of the result are also updated with the layout
/// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are already handled by updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();

  // Iterate over all the results.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are needed only for vector and tensor descriptor types.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has no layout but has users, emit a warning and continue.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // If the result is a tensor descriptor type, update the tensor desc type
    // with the layout.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // If the result is a vector type, add a temporary layout attribute to the
    // op.
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}

/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments,
/// their types must be updated. Also, a region op can have results that may
/// not have any users (e.g. A and B tiles). They are not assigned a layout by
/// the layout analysis because they have no users. However, inside the region
/// op the corresponding block arguments for these results do have layouts.
/// Therefore, in this case we still need to update the result types with the
/// layout attribute. This function updates the internal block arguments and
/// the result types of the region op with the assigned layouts.
/// clang-format off
/// Example: scf.for ... iter_args(...) -> (out types) {
///   ^bb0(block types):
///     ...
///     scf.yield ... : (yield types)
/// }
/// clang-format on
/// In this example, at scf.yield, control flow can transfer to two successor
/// regions. One is ^bb0 (the loop body) and the other is the scf.for op
/// itself (yielding the results). So we update both the block arguments of
/// the successor region (i.e. block types) and the result types of the
/// scf.for op (i.e. out types). Note that yield types are updated by the
/// respective producers inside ^bb0.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  // Only process if the terminator is inside a region branch op.
  auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
  if (!branchOp)
    return success();

  llvm::SmallDenseMap<OpOperand *, SmallVector<Value>> mapping;
  branchOp.getSuccessorOperandInputMapping(mapping,
                                           RegionBranchPoint(terminator));
  for (const auto &[successorOperand, successorInputs] : mapping) {
    for (Value successorInput : successorInputs) {
      Type inputType = successorInput.getType();
      // We only need to operate on tensor descriptor or vector types.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand->get());

      // If either of the layouts is not assigned, we cannot proceed.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand->get() << "\n");
        return failure();
      }
      // We expect the layouts to match.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        return failure();
      }
      // Get the tensor descriptor type with the layout.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // If the type is a vector type and this region argument is an OpResult,
      // set the layout attribute on the OpResult.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
1463
1464/// Update the function arguments and results with the layouts.
1465static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
1466 mlir::FunctionOpInterface funcOp,
1467 GetLayoutFnTy getLayoutOfValue) {
1468 SmallVector<Type> newArgTypes;
1469 // Update the function arguments.
1470 for (BlockArgument arg : funcOp.getArguments()) {
1471 Type argType = arg.getType();
1472 newArgTypes.push_back(argType);
1473 if (!isa<VectorType, xegpu::TensorDescType>(argType))
1474 continue;
1475 xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
1476 if (!layout) {
1477 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
1478 << " but got none.\n");
1479 return failure();
1480 }
1481 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
1482 auto newTdescTy = xegpu::TensorDescType::get(
1483 tensorDescTy.getContext(), tensorDescTy.getShape(),
1484 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
1485 arg.setType(newTdescTy);
1486 newArgTypes.back() = newTdescTy;
1487 }
1488 }
1489 // Update the function type with the new argument types.
1490 // NOTE: We assume that function results do not carry layouts.
1491 funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
1492 funcOp.getResultTypes()));
1493 return success();
1494}
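// A hedged sketch of the argument rewrite above; the function name, shape,
// and layout are assumed for illustration. A tensor-descriptor argument gets
// the propagated layout folded into its type, and the function type is
// rebuilt to match:
// clang-format off
//   gpu.func @kernel(%arg0: !xegpu.tensor_desc<8x16xf32>)
// becomes
//   gpu.func @kernel(%arg0: !xegpu.tensor_desc<8x16xf32,
//       #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>)
// clang-format on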
1495
1496namespace {
1497struct XeGPUPropagateLayoutPass final
1498 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
1499 XeGPUPropagateLayoutPass() = default;
1500 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
1501 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
1502 : XeGPUPropagateLayoutBase(options) {}
1503 void runOnOperation() override;
1504};
1505
1506} // namespace
1507
1508LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
1509 LayoutKind layoutKind, bool printOnly) {
1510 RunLayoutInfoPropagation analysis(target, layoutKind);
1511 // Print the analysis result and exit (for debugging purposes).
1512 if (printOnly) {
1513 auto &os = llvm::outs();
1514 analysis.printAnalysisResult(os);
1515 return success();
1516 }
1517 // Helper to convert LayoutInfo to a xegpu::DistributeLayoutAttr.
1518 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
1519 LayoutInfo layout = analysis.getLayoutInfo(val);
1520 if (!layout.isAssigned())
1521 return {};
1522 if (auto opResult = dyn_cast<OpResult>(val)) {
1523
1524 Operation *defOp = opResult.getDefiningOp();
1525 if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
1526 auto anchorLayout = anchorOp.getAnchorLayout();
1527 if (anchorLayout != nullptr)
1528 return anchorLayout;
1529 }
1530 xegpu::DistributeLayoutAttr requiredResLayoutAttr =
1531 xegpu::getTemporaryLayout(opResult);
1532 if (requiredResLayoutAttr != nullptr)
1533 return requiredResLayoutAttr;
1534 }
1535 xegpu::DistributeLayoutAttr layoutAttr =
1536 cast<xegpu::DistributeLayoutAttr>(layout.get());
1537 if (layout.isSliceLayout())
1538 return cast<xegpu::SliceAttr>(layoutAttr);
1539
1540 return cast<xegpu::LayoutAttr>(layoutAttr);
1541 };
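  // Note: the helper above resolves a value's layout with the following
  // precedence: an anchor layout reported by the defining op wins, then a
  // temporary layout recorded on the OpResult, and only then the layout
  // computed by the propagation analysis (returned as either a SliceAttr or
  // a plain LayoutAttr).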
1542
1543 Operation *op = target;
1544 auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
1545 for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
1546 LogicalResult r = success();
1547 TypeSwitch<Operation *>(&op)
1548 .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
1549 r = updateControlFlowOps(builder, branchTermOp,
1550 getXeGPULayoutForValue);
1551 })
1552 .Case([&](mlir::FunctionOpInterface funcOp) {
1553 r = updateFunctionOpInterface(builder, funcOp,
1554 getXeGPULayoutForValue);
1555 })
1556 .Default([&](Operation *op) {
1557 r = updateOp(builder, op, getXeGPULayoutForValue);
1558 });
1559 if (failed(r)) {
1560 op.emitError("failed to update operation with the layout");
1561 return WalkResult::interrupt();
1562 }
1563 }
1564 return WalkResult::advance();
1565 });
1566 if (walkResult.wasInterrupted())
1567 return failure();
1568
1569 return success();
1570}
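// A minimal usage sketch (the `target` op is assumed) showing how the entry
// point above can be driven directly from C++, mirroring the pass below:
//
//   OpBuilder builder(target->getContext());
//   if (failed(xegpu::propagateLayouts(builder, target,
//                                      xegpu::LayoutKind::Lane,
//                                      /*printOnly=*/false)))
//     return failure();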
1571
1572LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
1573 ResolveLayoutConflicts resolver(target);
1574 return resolver.run();
1575}
1576
1577void XeGPUPropagateLayoutPass::runOnOperation() {
1578 xegpu::LayoutKind layoutKind;
1579 if (this->layoutKind == "lane") {
1580 layoutKind = xegpu::LayoutKind::Lane;
1581 } else if (this->layoutKind == "inst") {
1582 layoutKind = xegpu::LayoutKind::InstData;
1583 } else if (this->layoutKind == "subgroup") {
1584 layoutKind = xegpu::LayoutKind::Subgroup;
1585 } else {
1586 getOperation()->emitError("unsupported layout kind option: " +
1587 this->layoutKind);
1588 signalPassFailure();
1589 return;
1590 }
1591 OpBuilder builder(&getContext());
1592 if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
1593 this->printOnly))) {
1594 signalPassFailure();
1595 return;
1596 }
1597 // Resolve layout conflicts if any.
1598 if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
1599 signalPassFailure();
1600 return;
1601 }
1602}
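// A hedged invocation sketch: the pass mnemonic and option spellings come
// from the pass tablegen definition and are assumed here; the accepted
// layout-kind values ("lane", "inst", "subgroup") follow runOnOperation
// above.
//
//   mlir-opt --xegpu-propagate-layout="layout-kind=lane" input.mlir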