30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/InterleavedRange.h"
37 #include "llvm/Support/LogicalResult.h"
38 #include "llvm/Support/raw_ostream.h"
42 #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
43 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
47 #define DEBUG_TYPE "xegpu-propagate-layout"
48 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
64 Layout(std::initializer_list<int64_t> list) : layout(list) {}
65 void print(llvm::raw_ostream &os)
const;
66 size_t size()
const {
return layout.size(); }
70 os << llvm::interleaved_array(layout);
76 using LaneLayout = Layout;
77 using LaneData = Layout;
101 LaneLayout laneLayout;
103 xegpu::LayoutAttr layoutAttr;
106 LayoutInfo() =
default;
107 LayoutInfo(
const LaneLayout &layout,
const LaneData &data)
108 : laneLayout(layout), laneData(data) {}
112 bool operator==(
const LayoutInfo &other)
const {
113 return this->isAssigned() == other.isAssigned();
116 static LayoutInfo meet(
const LayoutInfo &lhs,
const LayoutInfo &rhs);
118 static LayoutInfo join(
const LayoutInfo &lhs,
const LayoutInfo &rhs);
120 void print(raw_ostream &os)
const;
122 bool isAssigned()
const {
123 return laneLayout.size() > 0 && laneData.size() > 0;
128 const LaneLayout &getLayout()
const {
return laneLayout; }
129 const LaneData &getData()
const {
return laneData; }
136 os <<
"lane_layout: ";
137 laneLayout.print(os);
138 os <<
", lane_data: ";
141 os <<
"Not assigned.";
145 LayoutInfo LayoutInfo::meet(
const LayoutInfo &lhs,
const LayoutInfo &rhs) {
146 if (!lhs.isAssigned())
152 LayoutInfo LayoutInfo::join(
const LayoutInfo &lhs,
const LayoutInfo &rhs) {
153 llvm_unreachable(
"Join should not be triggered by layout propagation.");
161 LaneLayout newLayout;
163 for (int64_t idx : permutation) {
164 newLayout.layout.push_back(laneLayout.layout[idx]);
165 newData.layout.push_back(laneData.layout[idx]);
167 return LayoutInfo(newLayout, newData);
175 struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
177 using Lattice::Lattice;
187 static LayoutInfo getDefaultSIMTLayoutInfo(
unsigned rank) {
188 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
197 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
199 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
200 "Expected 1D or 2D vector.");
202 assert(vectorTy.getElementType().isIntOrFloat() &&
203 "Expected int or float element type.");
205 if (vectorTy.getRank() == 1)
206 return getDefaultSIMTLayoutInfo(1);
208 int packingFactor = 1;
209 unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
213 LaneData({1, packingFactor}));
217 static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
219 assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
220 "Expected 1D or 2D TensorDesc.");
222 assert(tdescTy.getElementType().isIntOrFloat() &&
223 "Expected int or float element type.");
225 if (tdescTy.getRank() == 1)
226 return getDefaultSIMTLayoutInfo(1);
228 unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
230 if (tdescTy.isScattered()) {
236 LaneData({1, packingFactor}));
244 LaneData({1, packingFactor}));
253 static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
254 unsigned operandNum) {
255 Type elementTy = vectorTy.getElementType();
257 "Expected int or float type in DPAS operands");
266 return LayoutInfo(layout, data);
269 return getDefaultSIMTLayoutInfo(vectorTy);
281 class LayoutInfoPropagation
287 void visitStoreNdOp(xegpu::StoreNdOp store,
291 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
295 void visitLoadNdOp(xegpu::LoadNdOp load,
299 void visitLoadGatherOp(xegpu::LoadGatherOp load,
303 void visitTransposeOp(vector::TransposeOp transpose,
307 void visitVectorBitcastOp(vector::BitCastOp bitcast,
311 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
315 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
319 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
323 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
337 void visitBranchOperand(
OpOperand &operand)
override {};
339 void visitCallOperand(
OpOperand &operand)
override {};
341 void visitExternalCall(CallOpInterface call,
346 void setToExitState(LayoutInfoLattice *lattice)
override {
347 (void)lattice->meet(LayoutInfo());
352 LogicalResult LayoutInfoPropagation::visitOperation(
356 .Case<xegpu::DpasOp>(
357 [&](
auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
358 .Case<xegpu::StoreNdOp>(
359 [&](
auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
360 .Case<xegpu::StoreScatterOp>([&](
auto storeScatterOp) {
361 visitStoreScatterOp(storeScatterOp, operands, results);
363 .Case<xegpu::LoadNdOp>(
364 [&](
auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
365 .Case<xegpu::LoadGatherOp>([&](
auto loadGatherOp) {
366 visitLoadGatherOp(loadGatherOp, operands, results);
368 .Case<xegpu::CreateDescOp>([&](
auto createDescOp) {
369 visitCreateDescOp(createDescOp, operands, results);
371 .Case<xegpu::UpdateNdOffsetOp>([&](
auto updateNdOffsetOp) {
372 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
374 .Case<xegpu::PrefetchNdOp>([&](
auto prefetchNdOp) {
375 visitPrefetchNdOp(prefetchNdOp, operands, results);
377 .Case<vector::TransposeOp>([&](
auto transposeOp) {
378 visitTransposeOp(transposeOp, operands, results);
380 .Case<vector::BitCastOp>([&](
auto bitcastOp) {
381 visitVectorBitcastOp(bitcastOp, operands, results);
383 .Case<vector::MultiDimReductionOp>([&](
auto reductionOp) {
384 visitVectorMultiReductionOp(reductionOp, operands, results);
388 for (
const LayoutInfoLattice *resultInfo : results) {
389 if (!resultInfo->getValue().isAssigned())
391 for (
auto [operandInfo, operand] :
395 if (!isa<xegpu::TensorDescType, VectorType>(
396 operand.get().getType()))
399 meet(operandInfo, *resultInfo);
407 void LayoutInfoPropagation::visitPrefetchNdOp(
412 auto tdescTy = prefetch.getTensorDescType();
413 auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
415 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
418 void LayoutInfoPropagation::visitVectorMultiReductionOp(
419 vector::MultiDimReductionOp reduction,
423 LayoutInfo resultLayout = results[0]->getValue();
424 if (!resultLayout.isAssigned())
427 VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
428 if (!resultTy || resultTy.getRank() != 1) {
429 reduction.emitWarning(
"Expecting output type to be 1D vector.");
434 LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
435 propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
437 propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
442 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
443 xegpu::UpdateNdOffsetOp updateNdOffset,
447 LayoutInfo resultLayout = results[0]->getValue();
448 if (!resultLayout.isAssigned())
451 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
455 void LayoutInfoPropagation::visitDpasOp(
458 VectorType aTy = dpas.getLhsType();
459 VectorType bTy = dpas.getRhsType();
461 operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
463 operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
464 if (operands.size() > 2) {
465 VectorType cTy = dpas.getAccType();
468 operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
473 void LayoutInfoPropagation::visitStoreNdOp(
476 LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
478 for (LayoutInfoLattice *operand : operands)
479 propagateIfChanged(operand, operand->meet(storeLayout));
484 void LayoutInfoPropagation::visitLoadNdOp(
487 LayoutInfo valueLayout = results[0]->getValue();
489 if (!valueLayout.isAssigned())
491 LayoutInfo tensorDescLayout = valueLayout;
495 if (
auto transpose = load.getTranspose()) {
496 load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
497 "LayoutInfoPropagation stage.");
498 tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
501 propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
506 void LayoutInfoPropagation::visitTransposeOp(
510 LayoutInfo resultLayout = results[0]->getValue();
511 if (!resultLayout.isAssigned())
513 LayoutInfo newLayout =
514 resultLayout.getTransposedLayout(transpose.getPermutation());
516 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
521 void LayoutInfoPropagation::visitVectorBitcastOp(
525 LayoutInfo resultLayout = results[0]->getValue();
526 if (!resultLayout.isAssigned())
528 int inElemTyBitWidth =
529 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
530 int outElemTyBitWidth =
531 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
535 if (inElemTyBitWidth != outElemTyBitWidth) {
536 bitcast.emitWarning(
"Widening or narrowing bitcasts are not expected at "
537 "layout propagation stage.");
541 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
546 void LayoutInfoPropagation::visitLoadGatherOp(
550 LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
553 LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
556 propagateIfChanged(operands[0], operands[0]->meet(layout));
558 propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
563 void LayoutInfoPropagation::visitCreateDescOp(
566 LayoutInfo descLayout = results[0]->getValue();
568 if (!descLayout.isAssigned())
571 LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
572 propagateIfChanged(operands[1], operands[1]->meet(layout));
577 void LayoutInfoPropagation::visitStoreScatterOp(
584 if (tdescShape.size() > 1)
587 "Expected the first dimension of 2D tensor descriptor to be equal to "
591 getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
594 propagateIfChanged(operands[0], operands[0]->meet(layout));
596 propagateIfChanged(operands[1], operands[1]->meet(layout));
598 LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
599 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
608 class RunLayoutInfoPropagation {
612 RunLayoutInfoPropagation(
Operation *op) : target(op) {
615 solver.load<LayoutInfoPropagation>(symbolTable);
616 (void)solver.initializeAndRun(op);
619 LayoutInfo getLayoutInfo(
Value val);
621 void printAnalysisResult(llvm::raw_ostream &os);
629 LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(
Value val) {
630 auto *state = solver.lookupState<LayoutInfoLattice>(val);
633 return state->getValue();
637 void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
638 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
639 os <<
"function: " << funcOp.getName() <<
":\n";
642 LayoutInfo layout = getLayoutInfo(arg);
643 os <<
"argument: " << arg <<
"\n";
655 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
662 LayoutInfo layout = getLayoutInfo(r);
663 os <<
"layout for result #" << i <<
": ";
671 if (
auto modOp = dyn_cast<ModuleOp>(target)) {
672 for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
673 funcOps.push_back(funcOp);
676 for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
677 for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
678 funcOps.push_back(gpuFuncOp);
682 for (FunctionOpInterface funcOp : funcOps)
683 printFunctionResult(funcOp);
695 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
700 Type resultType = result.getType();
702 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
705 xegpu::LayoutAttr layout = getLayoutOfValue(result);
706 if (!layout && result.getNumUses() > 0) {
707 op->
emitWarning(
"op has users but no layout assigned for its result");
712 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
714 tensorDescTy.getContext(), tensorDescTy.getShape(),
715 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
716 result.setType(typeWithLayout);
750 mlir::RegionBranchTerminatorOpInterface terminator,
753 if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
759 terminator.getSuccessorRegions(operands, successors);
763 terminator.getSuccessorOperands(successor);
765 for (
auto [successorOperand, successorInput] :
766 llvm::zip(successorOperands, successorInputs)) {
767 Type inputType = successorInput.getType();
769 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
771 xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
772 xegpu::LayoutAttr successorOperandLayout =
773 getLayoutOfValue(successorOperand);
776 if (!successorOperandLayout) {
779 <<
"No layout assigned for forwarded operand in branch terminator: "
780 << successorOperand <<
"\n");
784 if (successorInputLayout &&
785 successorInputLayout != successorOperandLayout) {
786 LLVM_DEBUG(
DBGS() <<
"Conflicting layouts for region argument and "
787 "operand forwarded as the argument: "
788 << successorInputLayout <<
" vs "
789 << successorOperandLayout <<
"\n");
793 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
795 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
796 tdescTy.getEncoding(), successorOperandLayout);
797 successorInput.setType(newTdescTy);
802 if (
auto result = dyn_cast<OpResult>(successorInput))
811 mlir::FunctionOpInterface funcOp,
816 Type argType = arg.getType();
817 newArgTypes.push_back(argType);
818 if (!isa<VectorType, xegpu::TensorDescType>(argType))
820 xegpu::LayoutAttr layout = getLayoutOfValue(arg);
822 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
823 <<
" but got none.\n");
826 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
828 tensorDescTy.getContext(), tensorDescTy.getShape(),
829 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
830 arg.setType(newTdescTy);
831 newArgTypes.back() = newTdescTy;
837 funcOp.getResultTypes()));
842 struct XeGPUPropagateLayoutPass final
843 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
844 XeGPUPropagateLayoutPass() =
default;
845 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
846 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
847 : XeGPUPropagateLayoutBase(
options) {}
848 void runOnOperation()
override;
853 void XeGPUPropagateLayoutPass::runOnOperation() {
854 auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
857 auto &os = llvm::outs();
858 analysis.printAnalysisResult(os);
862 auto getXeGPULayoutForValue = [&](
Value val) -> xegpu::LayoutAttr {
863 LayoutInfo layout = analysis.getLayoutInfo(val);
864 if (!layout.isAssigned())
867 val.
getContext(), llvm::to_vector_of<int>(layout.getLayoutAsArrayRef()),
868 llvm::to_vector_of<int>(layout.getDataAsArrayRef()));
875 LogicalResult r = success();
876 TypeSwitch<Operation *>(&op)
877 .Case<mlir::RegionBranchTerminatorOpInterface>(
878 [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
879 r = updateControlFlowOps(builder, branchTermOp,
880 getXeGPULayoutForValue);
882 .Case<mlir::FunctionOpInterface>(
883 [&](mlir::FunctionOpInterface funcOp) {
885 getXeGPULayoutForValue);
888 r =
updateOp(builder, op, getXeGPULayoutForValue);
891 op.
emitError(
"Failed to update operation with the layout.");
897 if (walkResult.wasInterrupted()) {
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
The general data-flow analysis solver.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one).
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by implementing transfer functions for operations.
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run properly.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
constexpr unsigned packedSizeInBitsForDpasB
constexpr unsigned subgroupSize
constexpr unsigned packedSizeInBitsForGatherScatter
constexpr unsigned packedSizeInBitsForDefault
If DPAS A or B operands have low precision element types they must be packed according to the following rules.
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)
Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...
Include the generated interface declarations.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.