31 #include "llvm/ADT/ArrayRef.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/InterleavedRange.h"
38 #include "llvm/Support/LogicalResult.h"
39 #include "llvm/Support/raw_ostream.h"
43 #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
44 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
48 #define DEBUG_TYPE "xegpu-propagate-layout"
49 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// NOTE(review): this listing is a mangled extraction — the leading numbers on
// each line are the original file's line numbers fused into the text, and
// interior lines are missing. Comments describe only the visible fragments.
//
// `Layout` wraps a small vector of int64_t describing a per-lane layout; it is
// constructible from an initializer list and printable as a bracketed list.
65 Layout(std::initializer_list<int64_t> list) : layout(list) {}
66 void print(llvm::raw_ostream &os)
const;
67 size_t size()
const {
return layout.size(); }
// Layout::print body (visible fragment): emits e.g. "[a, b]" via
// llvm::interleaved_array.
71 os << llvm::interleaved_array(layout);
// Lane layout (lanes per dimension) and lane data (elements per lane) share
// the same underlying representation.
77 using LaneLayout = Layout;
78 using LaneData = Layout;
// NOTE(review): mangled extraction — fused line numbers, missing interior
// lines. Fragments of the `LayoutInfo` lattice value class.
102 LaneLayout laneLayout;
// `laneData` member is referenced below but its declaration line is missing
// from this listing — presumably declared alongside `laneLayout`.
104 xegpu::LayoutAttr layoutAttr;
107 LayoutInfo() =
default;
108 LayoutInfo(
const LaneLayout &layout,
const LaneData &data)
109 : laneLayout(layout), laneData(data) {}
// CAUTION: equality compares only the *assigned* status of the two infos,
// not the actual layout contents (see the return below). Two different
// assigned layouts compare equal.
113 bool operator==(
const LayoutInfo &other)
const {
114 return this->isAssigned() == other.isAssigned();
// Lattice operations required by the sparse dataflow framework.
117 static LayoutInfo meet(
const LayoutInfo &lhs,
const LayoutInfo &rhs);
119 static LayoutInfo join(
const LayoutInfo &lhs,
const LayoutInfo &rhs);
121 void print(raw_ostream &os)
const;
// A layout is "assigned" once both lane layout and lane data are non-empty.
123 bool isAssigned()
const {
124 return laneLayout.size() > 0 && laneData.size() > 0;
129 const LaneLayout &getLayout()
const {
return laneLayout; }
130 const LaneData &getData()
const {
return laneData; }
// print body fragments: "lane_layout: <...>, lane_data: <...>" when
// assigned, otherwise "Not assigned.".
137 os <<
"lane_layout: ";
138 laneLayout.print(os);
139 os <<
", lane_data: ";
142 os <<
"Not assigned.";
// Lattice meet. Body is truncated in this listing: only the
// `!lhs.isAssigned()` guard is visible — presumably it returns `rhs` when
// `lhs` is unassigned and `lhs` otherwise (TODO: confirm against upstream).
146 LayoutInfo LayoutInfo::meet(
const LayoutInfo &lhs,
const LayoutInfo &rhs) {
147 if (!lhs.isAssigned())
153 LayoutInfo LayoutInfo::join(
const LayoutInfo &lhs,
const LayoutInfo &rhs) {
154 llvm_unreachable(
"Join should not be triggered by layout propagation.");
// Fragment of LayoutInfo::getTransposedLayout (signature line is missing
// from this listing — presumably takes an ArrayRef<int64_t> permutation).
// Builds new lane layout/data by permuting both vectors by `permutation`.
162 LaneLayout newLayout;
164 for (int64_t idx : permutation) {
165 newLayout.layout.push_back(laneLayout.layout[idx]);
166 newData.layout.push_back(laneData.layout[idx]);
168 return LayoutInfo(newLayout, newData);
// Lattice wrapper used by the sparse dataflow solver; inherits all behavior
// from mlir::dataflow::Lattice<LayoutInfo>.
176 struct LayoutInfoLattice :
public Lattice<LayoutInfo> {
178 using Lattice::Lattice;
// Default SIMT layout for a plain 1-D or 2-D value, by rank. Body beyond the
// assert is missing from this listing.
188 static LayoutInfo getDefaultSIMTLayoutInfo(
unsigned rank) {
189 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
// Default SIMT layout for a vector type. 1-D vectors delegate to the rank
// overload; 2-D vectors compute a packing factor from the element bitwidth
// (the computation lines are missing here — presumably packing applies when
// bitwidth is below the default packed size; TODO confirm upstream).
198 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
200 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
201 "Expected 1D or 2D vector.");
203 assert(vectorTy.getElementType().isIntOrFloat() &&
204 "Expected int or float element type.");
206 if (vectorTy.getRank() == 1)
207 return getDefaultSIMTLayoutInfo(1);
209 int packingFactor = 1;
210 unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
214 LaneData({1, packingFactor}));
// Default SIMT layout for a TensorDesc. Scattered descriptors take a
// different branch than blocked ones; both visible return fragments use
// LaneData({1, packingFactor}).
218 static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
220 assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
221 "Expected 1D or 2D TensorDesc.");
223 assert(tdescTy.getElementType().isIntOrFloat() &&
224 "Expected int or float element type.");
226 if (tdescTy.getRank() == 1)
227 return getDefaultSIMTLayoutInfo(1);
229 unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
231 if (tdescTy.isScattered()) {
237 LaneData({1, packingFactor}));
245 LaneData({1, packingFactor}));
// Layout for a DPAS operand; operandNum selects A/B/C-specific handling.
// Non-operand-B cases fall through to the default vector layout.
254 static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
255 unsigned operandNum) {
256 Type elementTy = vectorTy.getElementType();
258 "Expected int or float type in DPAS operands");
267 return LayoutInfo(layout, data);
270 return getDefaultSIMTLayoutInfo(vectorTy);
// Backward sparse dataflow analysis that propagates LayoutInfo from op
// results to op operands. One visit* hook per supported XeGPU / vector op.
282 class LayoutInfoPropagation
288 void visitStoreNdOp(xegpu::StoreNdOp store,
292 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
296 void visitLoadNdOp(xegpu::LoadNdOp load,
300 void visitLoadGatherOp(xegpu::LoadGatherOp load,
304 void visitTransposeOp(vector::TransposeOp transpose,
308 void visitVectorBitcastOp(vector::BitCastOp bitcast,
312 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
316 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
320 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
324 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
// Branch and call operands get no special layout handling (intentional
// no-ops).
338 void visitBranchOperand(
OpOperand &operand)
override {};
340 void visitCallOperand(
OpOperand &operand)
override {};
342 void visitExternalCall(CallOpInterface call,
// Exit state: meet with a default-constructed (unassigned) LayoutInfo;
// result of meet is intentionally discarded.
347 void setToExitState(LayoutInfoLattice *lattice)
override {
348 (void)lattice->meet(LayoutInfo());
// Dispatch on the concrete op type; each case forwards to the matching
// visit* handler with the operand/result lattices.
353 LogicalResult LayoutInfoPropagation::visitOperation(
357 .Case<xegpu::DpasOp>(
358 [&](
auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
359 .Case<xegpu::StoreNdOp>(
360 [&](
auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
361 .Case<xegpu::StoreScatterOp>([&](
auto storeScatterOp) {
362 visitStoreScatterOp(storeScatterOp, operands, results);
364 .Case<xegpu::LoadNdOp>(
365 [&](
auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
366 .Case<xegpu::LoadGatherOp>([&](
auto loadGatherOp) {
367 visitLoadGatherOp(loadGatherOp, operands, results);
369 .Case<xegpu::CreateDescOp>([&](
auto createDescOp) {
370 visitCreateDescOp(createDescOp, operands, results);
372 .Case<xegpu::UpdateNdOffsetOp>([&](
auto updateNdOffsetOp) {
373 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
375 .Case<xegpu::PrefetchNdOp>([&](
auto prefetchNdOp) {
376 visitPrefetchNdOp(prefetchNdOp, operands, results);
378 .Case<vector::TransposeOp>([&](
auto transposeOp) {
379 visitTransposeOp(transposeOp, operands, results);
381 .Case<vector::BitCastOp>([&](
auto bitcastOp) {
382 visitVectorBitcastOp(bitcastOp, operands, results);
384 .Case<vector::MultiDimReductionOp>([&](
auto reductionOp) {
385 visitVectorMultiReductionOp(reductionOp, operands, results);
// Default/fallback handling (fragment): for every assigned result layout,
// meet it into each operand whose type is a TensorDesc or Vector.
389 for (
const LayoutInfoLattice *resultInfo : results) {
390 if (!resultInfo->getValue().isAssigned())
392 for (
auto [operandInfo, operand] :
396 if (!isa<xegpu::TensorDescType, VectorType>(
397 operand.get().getType()))
400 meet(operandInfo, *resultInfo);
// Prefetch: assign the default layout of the tensor descriptor to the sole
// operand (backward propagation; prefetch produces no result).
408 void LayoutInfoPropagation::visitPrefetchNdOp(
413 auto tdescTy = prefetch.getTensorDescType();
414 auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
416 propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
// Multi-dim reduction: requires an assigned 1-D result layout; source
// (operand 0) gets the default 2-D layout, accumulator (operand 1) gets the
// result's layout.
419 void LayoutInfoPropagation::visitVectorMultiReductionOp(
420 vector::MultiDimReductionOp reduction,
424 LayoutInfo resultLayout = results[0]->getValue();
425 if (!resultLayout.isAssigned())
428 VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
429 if (!resultTy || resultTy.getRank() != 1) {
430 reduction.emitWarning(
"Expecting output type to be 1D vector.");
435 LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
436 propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
438 propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
// UpdateNdOffset: forward the result's layout to the source descriptor.
443 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
444 xegpu::UpdateNdOffsetOp updateNdOffset,
448 LayoutInfo resultLayout = results[0]->getValue();
449 if (!resultLayout.isAssigned())
452 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
// DPAS: A, B (and C when present) each get their operand-specific layout.
456 void LayoutInfoPropagation::visitDpasOp(
459 VectorType aTy = dpas.getLhsType();
460 VectorType bTy = dpas.getRhsType();
462 operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
464 operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
465 if (operands.size() > 2) {
466 VectorType cTy = dpas.getAccType();
469 operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
// StoreNd: both the stored value and the descriptor get the default layout
// derived from the value type.
474 void LayoutInfoPropagation::visitStoreNdOp(
477 LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
479 for (LayoutInfoLattice *operand : operands)
480 propagateIfChanged(operand, operand->meet(storeLayout));
// LoadNd: propagate the loaded value's layout back to the descriptor; a
// transpose attribute is unexpected at this stage (warn) but is still
// honored by transposing the layout.
485 void LayoutInfoPropagation::visitLoadNdOp(
488 LayoutInfo valueLayout = results[0]->getValue();
490 if (!valueLayout.isAssigned())
492 LayoutInfo tensorDescLayout = valueLayout;
496 if (
auto transpose = load.getTranspose()) {
497 load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
498 "LayoutInfoPropagation stage.");
499 tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
502 propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
// Transpose: the source needs the result layout transposed by the op's
// permutation.
507 void LayoutInfoPropagation::visitTransposeOp(
511 LayoutInfo resultLayout = results[0]->getValue();
512 if (!resultLayout.isAssigned())
514 LayoutInfo newLayout =
515 resultLayout.getTransposedLayout(transpose.getPermutation());
517 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
// Bitcast: only same-bitwidth casts are expected here; the result layout is
// forwarded unchanged to the source.
522 void LayoutInfoPropagation::visitVectorBitcastOp(
526 LayoutInfo resultLayout = results[0]->getValue();
527 if (!resultLayout.isAssigned())
529 int inElemTyBitWidth =
530 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
531 int outElemTyBitWidth =
532 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
536 if (inElemTyBitWidth != outElemTyBitWidth) {
537 bitcast.emitWarning(
"Widening or narrowing bitcasts are not expected at "
538 "layout propagation stage.");
542 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
// LoadGather: descriptor gets the default scattered layout; the mask gets
// the default 1-D layout.
547 void LayoutInfoPropagation::visitLoadGatherOp(
551 LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
554 LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
557 propagateIfChanged(operands[0], operands[0]->meet(layout));
559 propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
// CreateDesc: the offsets operand (index 1) gets the default 1-D layout once
// the produced descriptor has a layout.
564 void LayoutInfoPropagation::visitCreateDescOp(
567 LayoutInfo descLayout = results[0]->getValue();
569 if (!descLayout.isAssigned())
572 propagateIfChanged(operands[1], operands[1]->meet(layout));
// (line above depends on `layout` computed on the missing lines; fragment)
572 LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
573 propagateIfChanged(operands[1], operands[1]->meet(layout));
// StoreScatter: value and descriptor (operands 0 and 1) share the
// descriptor's default layout; the mask (operand 2) gets the 1-D layout.
578 void LayoutInfoPropagation::visitStoreScatterOp(
585 if (tdescShape.size() > 1)
588 "Expected the first dimension of 2D tensor descriptor to be equal to "
592 getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
595 propagateIfChanged(operands[0], operands[0]->meet(layout));
597 propagateIfChanged(operands[1], operands[1]->meet(layout));
599 LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
600 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
// Driver: loads the LayoutInfoPropagation analysis into a DataFlowSolver,
// runs it over `op` at construction, and exposes per-value query/printing.
609 class RunLayoutInfoPropagation {
613 RunLayoutInfoPropagation(
Operation *op) : target(op) {
616 solver.load<LayoutInfoPropagation>(symbolTable);
617 (void)solver.initializeAndRun(op);
620 LayoutInfo getLayoutInfo(
Value val);
622 void printAnalysisResult(llvm::raw_ostream &os);
// Query: look up the lattice state for `val`; returns its value (null-state
// handling lines are missing from this listing).
630 LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(
Value val) {
631 auto *state = solver.lookupState<LayoutInfoLattice>(val);
634 return state->getValue();
// Debug printer: dumps per-function argument and per-op result layouts;
// branch-like ops are skipped. Collects functions from a ModuleOp and any
// nested gpu.module ops.
638 void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
639 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
640 os <<
"function: " << funcOp.getName() <<
":\n";
643 LayoutInfo layout = getLayoutInfo(arg);
644 os <<
"argument: " << arg <<
"\n";
656 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
663 LayoutInfo layout = getLayoutInfo(r);
664 os <<
"layout for result #" << i <<
": ";
672 if (
auto modOp = dyn_cast<ModuleOp>(target)) {
673 for (
auto funcOp : modOp.getOps<FunctionOpInterface>())
674 funcOps.push_back(funcOp);
677 for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
678 for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
679 funcOps.push_back(gpuFuncOp);
683 for (FunctionOpInterface funcOp : funcOps)
684 printFunctionResult(funcOp);
// Fragments of updateOp: for each result of Vector/TensorDesc type, fetch
// the computed layout; warn if a used result has none; rewrite TensorDesc
// result types to carry the layout attribute in their encoding.
696 if (mlir::isa<mlir::RegionBranchOpInterface>(op))
701 Type resultType = result.getType();
703 if (!isa<VectorType, xegpu::TensorDescType>(resultType))
706 xegpu::LayoutAttr layout = getLayoutOfValue(result);
707 if (!layout && result.getNumUses() > 0) {
708 op->
emitWarning(
"op has users but no layout assigned for its result");
713 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
715 tensorDescTy.getContext(), tensorDescTy.getShape(),
716 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
717 result.setType(typeWithLayout);
// Fragments of updateControlFlowOps: align layouts between terminator
// operands and the region successor inputs they forward to; debug-log
// missing or conflicting layouts; retype TensorDesc successor inputs.
751 mlir::RegionBranchTerminatorOpInterface terminator,
754 if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
760 terminator.getSuccessorRegions(operands, successors);
764 terminator.getSuccessorOperands(successor);
766 for (
auto [successorOperand, successorInput] :
767 llvm::zip(successorOperands, successorInputs)) {
768 Type inputType = successorInput.getType();
770 if (!isa<xegpu::TensorDescType, VectorType>(inputType))
772 xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
773 xegpu::LayoutAttr successorOperandLayout =
774 getLayoutOfValue(successorOperand);
777 if (!successorOperandLayout) {
780 <<
"No layout assigned for forwarded operand in branch terminator: "
781 << successorOperand <<
"\n");
785 if (successorInputLayout &&
786 successorInputLayout != successorOperandLayout) {
787 LLVM_DEBUG(
DBGS() <<
"Conflicting layouts for region argument and "
788 "operand forwarded as the argument: "
789 << successorInputLayout <<
" vs "
790 << successorOperandLayout <<
"\n");
794 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
796 tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
797 tdescTy.getEncoding(), successorOperandLayout);
798 successorInput.setType(newTdescTy);
803 if (
auto result = dyn_cast<OpResult>(successorInput))
// Fragments of updateFunctionOpInterface: rewrite TensorDesc-typed function
// arguments to carry their layout; collect updated argument types and
// re-apply the function type with the original result types.
812 mlir::FunctionOpInterface funcOp,
817 Type argType = arg.getType();
818 newArgTypes.push_back(argType);
819 if (!isa<VectorType, xegpu::TensorDescType>(argType))
821 xegpu::LayoutAttr layout = getLayoutOfValue(arg);
823 LLVM_DEBUG(
DBGS() <<
"Expecting layout for function argument: " << arg
824 <<
" but got none.\n");
827 if (
auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
829 tensorDescTy.getContext(), tensorDescTy.getShape(),
830 tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
831 arg.setType(newTdescTy);
832 newArgTypes.back() = newTdescTy;
838 funcOp.getResultTypes()));
// Pass definition (tablegen-generated base) plus runOnOperation fragments:
// run the analysis, optionally print it, then walk the IR rewriting types
// and attributes with the computed layouts via a TypeSwitch dispatch.
843 struct XeGPUPropagateLayoutPass final
844 :
public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
845 XeGPUPropagateLayoutPass() =
default;
846 XeGPUPropagateLayoutPass(
const XeGPUPropagateLayoutPass &other) =
default;
847 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions
options)
848 : XeGPUPropagateLayoutBase(
options) {}
849 void runOnOperation()
override;
854 void XeGPUPropagateLayoutPass::runOnOperation() {
855 auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
858 auto &os = llvm::outs();
859 analysis.printAnalysisResult(os);
// Convert an analysis LayoutInfo into an xegpu::LayoutAttr for `val`;
// returns null (missing line) when no layout was assigned.
863 auto getXeGPULayoutForValue = [&](
Value val) -> xegpu::LayoutAttr {
864 LayoutInfo layout = analysis.getLayoutInfo(val);
865 if (!layout.isAssigned())
868 val.
getContext(), llvm::to_vector_of<int>(layout.getLayoutAsArrayRef()),
869 llvm::to_vector_of<int>(layout.getDataAsArrayRef()));
// Walk body: region-branch terminators and function interfaces get special
// update paths; every other op goes through updateOp. Any failure emits an
// error and (per wasInterrupted below) aborts the walk.
876 LogicalResult r = success();
877 TypeSwitch<Operation *>(&op)
878 .Case<mlir::RegionBranchTerminatorOpInterface>(
879 [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
880 r = updateControlFlowOps(builder, branchTermOp,
881 getXeGPULayoutForValue);
883 .Case<mlir::FunctionOpInterface>(
884 [&](mlir::FunctionOpInterface funcOp) {
886 getXeGPULayoutForValue);
889 r =
updateOp(builder, op, getXeGPULayoutForValue);
892 op.
emitError(
"Failed to update operation with the layout.");
898 if (walkResult.wasInterrupted()) {
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
static LogicalResult updateControlFlowOps(mlir::OpBuilder &builder, mlir::RegionBranchTerminatorOpInterface terminator, GetLayoutFnTy getLayoutOfValue)
Region ops like scf.for need special handling because they have blocks inside.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, GetLayoutFnTy getLayoutOfValue)
Update an operation with the layout of its results.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, mlir::FunctionOpInterface funcOp, GetLayoutFnTy getLayoutOfValue)
Update the function arguments and results with the layouts.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType & getOperations()
The general data-flow analysis solver.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
This class represents a successor of a region.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
void loadBaselineAnalyses(DataFlowSolver &solver)
Populates a DataFlowSolver with analyses that are required to ensure user-defined analyses are run pr...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
constexpr unsigned packedSizeInBitsForDpasB
constexpr unsigned subgroupSize
constexpr unsigned packedSizeInBitsForGatherScatter
constexpr unsigned packedSizeInBitsForDefault
If DPAS A or B operands have low precision element types they must be packed according to the followi...
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)
Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...
Include the generated interface declarations.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...