20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
25 #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
26 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
// Layout: small helper holding a per-dimension integer layout vector; used
// both for the work-item lane layout and the per-lane data shape (see the
// WiLayout/WiData aliases below).
// NOTE(review): this view of the file is fragmented — interior lines are
// missing — so only comments are added here; code is untouched.
54 Layout(std::initializer_list<int64_t> list) : layout(list) {}
// Prints the layout as a comma-separated list (impl fragment at line 62).
55 void print(llvm::raw_ostream &os)
const;
// Number of dimensions in the layout.
56 size_t size()
const {
return layout.size(); }
// Bounds-checked element access (impl below asserts idx is in range).
57 int64_t operator[](
size_t idx)
const;
// Layout::print body fragment — emits the elements separated by commas.
62 llvm::interleaveComma(layout, os);
// Layout::operator[] — asserts before access; the return statement is not
// visible in this fragment.
66 int64_t Layout::operator[](
size_t idx)
const {
67 assert(idx < layout.size() &&
"Index out of bounds.");
// Aliases documenting intent: lane distribution vs. per-lane data shape.
74 using WiLayout = Layout;
75 using WiData = Layout;
// SGMap: the lattice value for the propagation analysis — couples a
// work-item lane layout (wiLayout) with a per-lane data shape (wiData).
102 SGMap(
const WiLayout &layout,
const WiData &data)
103 : wiLayout(layout), wiData(data) {}
// Equality fragment: compares only the assigned/unassigned status, not the
// actual layout values. NOTE(review): presumably intentional for lattice
// convergence — confirm against the meet() semantics.
108 return this->isAssigned() == other.isAssigned();
// Lattice operations required by the sparse backward dataflow framework.
111 static SGMap meet(
const SGMap &lhs,
const SGMap &rhs);
113 static SGMap join(
const SGMap &lhs,
const SGMap &rhs);
115 void print(raw_ostream &os)
const;
// A map counts as "assigned" once both layout and data have entries.
117 bool isAssigned()
const {
return wiLayout.size() > 0 && wiData.size() > 0; }
121 const WiLayout &getLayout()
const {
return wiLayout; }
122 const WiData &getData()
const {
return wiData; }
// SGMap::print fragment: unassigned maps print a placeholder string.
132 os <<
"Not assigned.";
// SGMap::meet — lattice meet used by the backward analysis. Only the
// unassigned-lhs guard is visible in this fragment; the remainder of the
// body (what is returned in each case) is not shown here.
135 SGMap SGMap::meet(
const SGMap &lhs,
const SGMap &rhs) {
136 if (!lhs.isAssigned())
// SGMap::join — deliberately unsupported: this analysis only uses meet.
142 SGMap SGMap::join(
const SGMap &lhs,
const SGMap &rhs) {
143 llvm_unreachable(
"Join should not be triggered by SGMapPropagation.");
// getTransposedLayout fragment: permutes both the lane layout and the data
// shape by `permutation` and returns the transposed map. The signature and
// the declarations of newLayout/newData are not visible here.
152 for (
auto idx : permutation) {
153 newLayout.layout.push_back(wiLayout.layout[idx]);
154 newData.layout.push_back(wiData.layout[idx]);
156 return SGMap(newLayout, newData);
// Lattice wrapper making SGMap usable with MLIR's dataflow framework.
164 struct SGMapLattice :
public Lattice<SGMap> {
166 using Lattice::Lattice;
// Default map for a 1-D or 2-D value: distribute the innermost dimension
// across the subgroup lanes, one element per lane. The rank==1 branch body
// is not visible in this fragment.
177 static SGMap getDefaultSgMap(
unsigned rank) {
178 assert((rank == 1 || rank == 2) &&
"Expected 1D or 2D vector.");
// 2-D case: [1, subgroupSize] lanes, [1, 1] data per lane.
181 return SGMap(WiLayout({1,
subgroupSize}), WiData({1, 1}));
// Overload for vector types: 1-D vectors fall back to the rank-based
// default; 2-D vectors get a packing factor per lane. NOTE(review): the
// computation of packingFactor from `bitwidth` is in lines not visible
// here — presumably packs low-precision elements; confirm in full source.
185 static SGMap getDefaultSgMap(VectorType vectorTy) {
187 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
188 "Expected 1D or 2D vector.");
190 assert(vectorTy.getElementType().isIntOrFloat() &&
191 "Expected int or float element type.");
193 if (vectorTy.getRank() == 1)
194 return getDefaultSgMap(1);
196 int packingFactor = 1;
197 auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
200 return SGMap(WiLayout({1,
subgroupSize}), WiData({1, packingFactor}));
// DPAS operands: operand B (operandNum == 1) gets a special packed map when
// some extra condition holds (the condition is truncated in this view —
// likely an element-bitwidth check; confirm); other operands use the
// default subgroup map.
209 static SGMap getSGMapForDPASOperand(VectorType vectorTy,
unsigned operandNum) {
210 auto elementTy = vectorTy.getElementType();
211 assert(elementTy.isIntOrFloat() &&
212 "Expected int or float type in DPAS operands");
// B-operand special case; `layout`/`data` construction not visible here.
216 if (operandNum == 1 &&
220 return SGMap(layout, data);
// Fallback for A, C, and non-packed B.
223 return getDefaultSgMap(vectorTy);
// SGMapPropagation: sparse *backward* dataflow analysis that propagates
// SGMap lattice values from op results back to operands. Only declaration
// fragments of the per-op visitors are visible in this view.
243 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
250 void visitLoadGatherOp(xegpu::LoadGatherOp load,
254 void visitTransposeOp(vector::TransposeOp
transpose,
258 void visitVectorBitcastOp(vector::BitCastOp bitcast,
262 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
266 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
270 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
// Control-flow hooks: branch/call operands get no layout handling here.
282 void visitBranchOperand(
OpOperand &operand)
override {};
284 void visitCallOperand(
OpOperand &operand)
override {};
286 void visitExternalCall(CallOpInterface call,
// Exit state: meet with a default-constructed (unassigned) map; the result
// of meet is intentionally discarded.
290 void setToExitState(SGMapLattice *lattice)
override {
291 (void)lattice->meet(SGMap());
// Central dispatcher: TypeSwitch on the op kind, forwarding to the
// specialized visitors. Several lines of the original (lambda closers, the
// Default case, early parts) are missing from this fragment.
297 SGMapPropagation::visitOperation(
Operation *op,
301 .Case<xegpu::DpasOp>(
302 [&](
auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
303 .Case<xegpu::StoreNdOp>(
304 [&](
auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
305 .Case<xegpu::StoreScatterOp>([&](
auto storeScatterOp) {
306 visitStoreScatterOp(storeScatterOp, operands, results);
308 .Case<xegpu::LoadNdOp>(
309 [&](
auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
310 .Case<xegpu::LoadGatherOp>([&](
auto loadGatherOp) {
311 visitLoadGatherOp(loadGatherOp, operands, results);
313 .Case<xegpu::CreateDescOp>([&](
auto createDescOp) {
314 visitCreateDescOp(createDescOp, operands, results);
316 .Case<xegpu::UpdateNdOffsetOp>([&](
auto updateNdOffsetOp) {
317 visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
// CreateNdDescOp intentionally propagates nothing (empty lambda).
321 .Case<xegpu::CreateNdDescOp>([&](
auto createNdDescOp) {})
322 .Case<vector::TransposeOp>([&](
auto transposeOp) {
323 visitTransposeOp(transposeOp, operands, results);
325 .Case<vector::BitCastOp>([&](
auto bitcastOp) {
326 visitVectorBitcastOp(bitcastOp, operands, results);
328 .Case<vector::MultiDimReductionOp>([&](
auto reductionOp) {
329 visitVectorMultiReductionOp(reductionOp, operands, results);
// Fallback fragment: for each assigned result layout, (presumably) meet it
// into every operand — the inner loop body is not visible here.
333 for (
const SGMapLattice *r : results) {
334 for (SGMapLattice *operand : operands) {
336 if (r->getValue().isAssigned())
// Make each result depend on the program point after `op`, so the solver
// re-visits this op when a result layout changes.
342 for (
const SGMapLattice *r : results) {
343 addDependency(
const_cast<SGMapLattice *
>(r), getProgramPointAfter(op));
// MultiDimReduction: the (asserted 1-D) result layout flows into the
// accumulator operand; the vector source gets the default 2-D map.
348 void SGMapPropagation::visitVectorMultiReductionOp(
// Nothing to do until the result layout has been assigned.
352 auto resultLayout = results[0]->getValue();
353 if (!resultLayout.isAssigned())
356 assert(resultLayout.getLayout().size() == 1 &&
357 "Expected 1D layout for reduction result.");
// Source vector operand: default 2-D subgroup map.
360 auto operandLayout = getDefaultSgMap(2);
361 propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
// Accumulator operand shares the result's layout.
363 propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
// UpdateNdOffset: the tensor-descriptor operand simply inherits the result
// layout (offsets themselves carry no layout).
368 void SGMapPropagation::visitUpdateNdOffsetOp(
372 auto resultLayout = results[0]->getValue();
373 if (!resultLayout.isAssigned())
376 propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
// DPAS: each operand gets the position-specific DPAS map (A=0, B=1, C=2);
// the accumulator C is optional.
380 void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
383 auto aTy = dpas.getLhsType();
384 auto bTy = dpas.getRhsType();
385 propagateIfChanged(operands[0],
386 operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
387 propagateIfChanged(operands[1],
388 operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
// Only propagate to the accumulator when it is present.
389 if (operands.size() > 2) {
390 auto cTy = dpas.getAccType();
391 propagateIfChanged(operands[2],
392 operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
// StoreNd: derive the default map from the stored value's type and meet it
// into every operand (value and tensor descriptor alike).
397 void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
400 auto storeLayout = getDefaultSgMap(store.getValueType());
402 for (SGMapLattice *operand : operands) {
403 propagateIfChanged(operand, operand->meet(storeLayout));
// LoadNd: propagate the loaded value's layout back to the tensor
// descriptor, undoing any transpose effect. A transpose here is unexpected
// at this stage, hence the warning rather than an error.
409 void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
412 auto valueLayout = results[0]->getValue();
414 if (!valueLayout.isAssigned())
416 SGMap tensorDescLayout = valueLayout;
419 if (
auto transpose = load.getTranspose()) {
420 load.emitWarning(
"Transpose effect is not expected for LoadNdOp at "
421 "SGMapPropagation stage.");
422 tensorDescLayout = valueLayout.getTransposedLayout(
transpose.value());
425 propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
// Transpose: apply the op's permutation to the result layout and meet the
// transposed map into the source operand.
430 void SGMapPropagation::visitTransposeOp(
434 auto resultLayout = results[0]->getValue();
435 if (!resultLayout.isAssigned())
437 auto newLayout = resultLayout.getTransposedLayout(
transpose.getPermutation());
439 propagateIfChanged(operands[0], operands[0]->meet(newLayout));
// BitCast: keep the lane layout, rescale the per-lane data by the ratio of
// element bitwidths — widening multiplies, narrowing divides. The scaled
// dimension is chosen by which data dim currently equals 1.
444 void SGMapPropagation::visitVectorBitcastOp(
448 auto resultLayout = results[0]->getValue();
449 if (!resultLayout.isAssigned())
451 auto inElemTyBitWidth =
452 bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
453 auto outElemTyBitWidth =
454 bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
457 const WiLayout &newWiLayout = resultLayout.getLayout();
458 const WiData &currData = resultLayout.getData();
// Widening cast: each source lane must supply `ratio` times more elements.
461 if (inElemTyBitWidth < outElemTyBitWidth) {
462 auto ratio = outElemTyBitWidth / inElemTyBitWidth;
463 newWiData = resultLayout.getData()[0] == 1
464 ? WiData({1, currData[1] * ratio})
465 : WiData({currData[0] * ratio, 1});
// Narrowing cast: `ratio` times fewer elements per lane. NOTE(review):
// integer division — divisibility of currData by ratio is not visibly
// checked in this fragment; confirm in the full source.
468 auto ratio = inElemTyBitWidth / outElemTyBitWidth;
469 newWiData = resultLayout.getData()[0] == 1
470 ? WiData({1, currData[1] / ratio})
471 : WiData({currData[0] / ratio, 1});
474 propagateIfChanged(operands[0],
475 operands[0]->meet(SGMap(newWiLayout, newWiData)));
// LoadGather: the tensor descriptor inherits the value layout (transposed
// {1,0} if a transpose effect is present — unexpected at this stage, hence
// the warning); the mask operand gets a default 1-D map.
480 void SGMapPropagation::visitLoadGatherOp(
483 auto valueLayout = results[0]->getValue();
485 if (!valueLayout.isAssigned())
488 SGMap tensorDescLayout = valueLayout;
489 if (load.getTranspose()) {
493 load.emitWarning(
"Transpose effect is not expected for LoadGatherOp at "
494 "SGMapPropagation stage.");
495 tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
498 auto maskLayout = getDefaultSgMap(1);
// operands[0] = tensor descriptor, operands[1] = mask.
500 propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
502 propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
// CreateDesc: once the descriptor's layout is known, the offset operand
// (operands[1]) gets a default 1-D map. Handling for operands[0] (if any)
// is not visible in this fragment.
507 void SGMapPropagation::visitCreateDescOp(
510 auto descLayout = results[0]->getValue();
512 if (!descLayout.isAssigned())
515 SGMap layout = getDefaultSgMap(1);
516 propagateIfChanged(operands[1], operands[1]->meet(layout));
// StoreScatter: validates that the descriptor's height dimension is evenly
// divisible by the subgroup size, then meets: the value operand with the
// default map for the stored type, the tensor descriptor with the
// (possibly {1,0}-transposed) variant, and the mask with a 1-D map.
521 void SGMapPropagation::visitStoreScatterOp(
527 auto tdescShape = storeScatter.getTensorDescType().getShape();
528 if (tdescShape.size() > 1 && tdescShape[0] %
subgroupSize != 0) {
529 storeScatter.emitError(
"Height dimension of the tensor descriptor should "
530 "be evenly divisible by the subgroup size.");
533 auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
534 SGMap storeScatterLayout = valueLayout;
535 if (storeScatter.getTranspose()) {
// Transpose here is unexpected at this stage — warn, then compensate.
539 storeScatter.emitWarning(
"Transpose effect is not expected for "
540 "StoreScatterOp at SGMapPropagation stage.");
541 storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
// operands: [0]=value, [1]=tensor descriptor, [2]=mask.
544 propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
546 propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
548 auto maskLayout = getDefaultSgMap(1);
549 propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
// Driver: owns a DataFlowSolver, loads the SGMapPropagation analysis, and
// runs it eagerly in the constructor over the target op.
559 class RunSGMapPropagation {
561 RunSGMapPropagation(
Operation *op) : target(op) {
565 solver.load<SGMapPropagation>(symbolTable);
// Result intentionally discarded; failures surface via unassigned maps.
566 (void)solver.initializeAndRun(op);
// Query the solved map for a value; debug printer for the whole module.
569 SGMap getSGMap(
Value val);
571 void printAnalysisResult(llvm::raw_ostream &os);
// Look up the solved lattice state for `val`. NOTE(review): the null-state
// handling (line between the lookup and the return) is not visible in this
// fragment — presumably returns a default SGMap when no state exists.
579 SGMap RunSGMapPropagation::getSGMap(
Value val) {
580 auto *state = solver.lookupState<SGMapLattice>(val);
583 return state->getValue();
// Debug printer: dumps the sg_map assigned to every function argument and
// every op result, for functions directly in the module and for functions
// nested inside gpu.module ops.
586 void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
587 auto printFunctionResult = [&](FunctionOpInterface funcOp) {
588 os <<
"function: " << funcOp.getName() <<
":\n";
// Block arguments first.
590 for (
auto arg : funcOp.getArguments()) {
591 auto layout = getSGMap(arg);
592 os <<
"argument: " << arg <<
"\n";
// Skip pure control-flow ops — they carry no layouts of interest.
604 if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
611 auto layout = getSGMap(r);
612 os <<
"sg_map for result #" << i <<
": ";
// Collect the functions to print: top-level ones, then those inside
// nested GPU modules.
620 if (
auto modOp = dyn_cast<ModuleOp>(target)) {
621 for (
auto funcOp : modOp.getOps<FunctionOpInterface>()) {
622 funcOps.push_back(funcOp);
625 for (
auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
626 for (
auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
627 funcOps.push_back(gpuFuncOp);
632 for (
auto funcOp : funcOps) {
633 printFunctionResult(funcOp);
// Pass definition; boilerplate generated from Passes.td via the
// GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE include near the top of the file.
638 struct XeGPUSubgroupDistributePass final
639 :
public xegpu::impl::XeGPUSubgroupDistributeBase<
640 XeGPUSubgroupDistributePass> {
641 XeGPUSubgroupDistributePass() =
default;
// Copy constructor fragment — its `= default`/definition continues on a
// line not visible in this view.
642 XeGPUSubgroupDistributePass(
const XeGPUSubgroupDistributePass &other) =
// Options-taking constructor forwards to the generated base class.
644 XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions
options)
645 : XeGPUSubgroupDistributeBase(
options) {}
646 void runOnOperation()
override;
// Pass entry point: run the propagation analysis over the operation and
// print the result to stdout. NOTE(review): the printing is presumably
// guarded by a pass option on a line not visible in this fragment.
650 void XeGPUSubgroupDistributePass::runOnOperation() {
652 RunSGMapPropagation solver(op);
656 auto &os = llvm::outs();
657 solver.printAnalysisResult(os);
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
constexpr unsigned subgroupSize
HW dependent constants.
constexpr unsigned packedSizeInBitsForDpasB
constexpr unsigned packedSizeInBitsForDefault
If DPAS A or B operands have low precision element types they must be packed according to the followi...
The general data-flow analysis solver.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
OperationName getName()
The name of an operation is the key identifier for it.
result_range getResults()
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
This class represents a lattice holding a specific value of type ValueT.
A sparse (backward) data-flow analysis for propagating SSA value lattices backwards across the IR by ...
SparseBackwardDataFlowAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
This analysis implements sparse constant propagation, which attempts to determine constant-valued res...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...