#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;
using namespace mlir::dataflow;

#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
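//===----------------------------------------------------------------------===//
// Overview: this pass runs a sparse backward dataflow analysis that
// propagates SIMT lane-level layout information (lane_layout / lane_data)
// from layout-anchoring operations (xegpu.dpas, stores, scatter/gather ops)
// toward the producers of their operands. The computed layouts are then
// written back as layout attributes on op results, embedded into
// xegpu.tensor_desc types, and reflected on function arguments.
//===----------------------------------------------------------------------===//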
namespace {

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//

/// Lattice value for the propagation analysis: a thin wrapper around an
/// optional xegpu::DistributeLayoutAttr.
struct LayoutInfo {
private:
  xegpu::DistributeLayoutAttr storage = nullptr;

public:
  LayoutInfo() = default;
  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}

  // Two lattice values are equal if both are assigned or both are
  // unassigned; the layout contents themselves are not compared.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const { return storage != nullptr; }

  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;

  xegpu::DistributeLayoutAttr get() const { return storage; }

  bool isSliceLayout() const {
    if (!isAssigned())
      return false;
    return isa<xegpu::SliceAttr>(storage);
  }

  int64_t getRank() const {
    assert(isAssigned() && "Expected layout to be assigned");
    return storage.getRank();
  }

  SmallVector<int> getLaneLayout() const {
    assert(storage.getEffectiveLaneLayoutAsInt().size() &&
           "Expected lane layout to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneLayoutAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }

  SmallVector<int> getLaneData() const {
    assert(storage.getEffectiveLaneDataAsInt().size() &&
           "Expected lane data to be assigned");
    return llvm::map_to_vector(
        storage.getEffectiveLaneDataAsInt(),
        [](int64_t val) { return static_cast<int>(val); });
  }
};

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned())
    os << storage;
  else
    os << "Not assigned.";
}
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

/// Since this is a backward analysis, join is never expected to be invoked.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
/// Construct a new layout with the transposed lane layout and lane data.
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  // Check that the permutation is valid: no duplicates and all indices
  // within [0, rank).
  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
  bool hasDuplicates = seen.size() != permutation.size();
  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
  });

  if (!withinRange || hasDuplicates) {
    assert(false && "Invalid permutation for transpose.");
    return {};
  }

  SmallVector<int32_t> laneLayout;
  SmallVector<int32_t> laneData;
  for (int64_t idx : permutation) {
    laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
    laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
  }
  return LayoutInfo(
      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
}
//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  using Lattice::Lattice;
};
/// Helper functions to get default layouts. A `default layout` is a layout
/// that is assigned to a value when the layout is not fixed by an anchor
/// operation (like DPAS).

/// Helper to get the default layout for a 1D or 2D uniform value (e.g. masks
/// and offsets) of the given rank.
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                           unsigned rank) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(
        xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
  return LayoutInfo(xegpu::LayoutAttr::get(
      ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
  // The packing factor is determined by the element type bitwidth.
  int packingFactor = 1;
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  if (isScattered) {
    packingFactor =
        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
            : 1;
    return LayoutInfo(xegpu::LayoutAttr::get(
        vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
        {1, packingFactor}));
  }
  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                           {1, xegpu::targetinfo::subgroupSize},
                                           {1, packingFactor}));
}
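// Worked example (assuming subgroupSize == 16 and packedSizeInBitsForDefault
// == 32, per xegpu::targetinfo): for vector<16x16xf16>, bitwidth = 16 < 32,
// so packingFactor = 32 / 16 = 2 and the default layout is
// lane_layout = [1, 16], lane_data = [1, 2]; for an f32 vector the packing
// factor stays 1, giving lane_data = [1, 1]. In the scattered case the lane
// layout flips to [16, 1], so each lane owns a column of elements.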
/// Helper to get the default layout for a TensorDesc type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                           bool isScattered = false) {
  // Expecting a 1D or 2D TensorDesc.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, return the default layout for a 1D TensorDesc.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1);
  // The packing factor is determined by the element type bitwidth.
  int packingFactor = 1;
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
  if (isScattered) {
    packingFactor =
        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
            : 1;
    return LayoutInfo(xegpu::LayoutAttr::get(
        tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
        {1, packingFactor}));
  }
  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(),
                                           {1, xegpu::targetinfo::subgroupSize},
                                           {1, packingFactor}));
}
/// Helper to get the layout of a DPAS operand (A: operandNum == 0,
/// B: operandNum == 1, C/accumulator: operandNum == 2).
static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
                                                  unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  SmallVector<int32_t, 2> layout(
      {1, static_cast<int32_t>(xegpu::targetinfo::subgroupSize)});
  // For the B operand, low-precision data must be packed in
  // packedSizeInBitsForDpasB-sized units along the row dimension.
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
                             xegpu::targetinfo::packedSizeInBitsForDpasB) {
    SmallVector<int32_t, 2> data(
        {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB /
                              elementTy.getIntOrFloatBitWidth()),
         1});
    return LayoutInfo(
        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultSIMTLayoutInfo(vectorTy);
}
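// Worked example (assuming subgroupSize == 16 and packedSizeInBitsForDpasB
// == 32): for a vector<16x16xf16> B operand, 16 < 32, so each lane packs
// 32 / 16 = 2 consecutive rows, i.e. lane_layout = [1, 16] with
// lane_data = [2, 1]. A and C operands fall through to the default layout.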
//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Backward sparse dataflow analysis that propagates layout information from
/// anchor operations (e.g. DPAS and stores) to the producers of their
/// operands.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);

  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
} // namespace
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
        visitVectorBroadCastOp(broadcastOp, operands, results);
      })
      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
        visitShapeCastOp(shapeCastOp, operands, results);
      })
      // All other ops: propagate each assigned result layout to all vector or
      // tensor descriptor operands.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // If the operand is not a vector or tensor descriptor, skip it.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            // Propagate the result layout to the operand.
            meet(operandInfo, *resultInfo);
          }
        }
      });
  return success();
}
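// Example for the default case (a sketch): given
//   %r = arith.addf %a, %b : vector<16x16xf32>
// any layout already assigned to %r is met into the lattices of %a and %b,
// so layouts flow backward unchanged through ops that need no special
// handling.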
/// Assign the default layout to the tensor descriptor operand of a prefetch.
void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  auto tdescTy = prefetch.getTensorDescType();
  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider reductions to a 1D vector.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  // The source operand gets the default 2D layout.
  LayoutInfo operandLayout =
      getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // The accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
void LayoutInfoPropagation::visitVectorBroadCastOp(
    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Only vector-to-vector broadcasts are supported.
  VectorType resultTy = broadcast.getResultVectorType();
  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
  if (!sourceTy) {
    broadcast.emitWarning("Expecting source type to be a vector type.");
    return;
  }
  // Only nD -> nD broadcasts with exactly one broadcasted unit dimension are
  // supported.
  if (sourceTy.getRank() != resultTy.getRank()) {
    broadcast.emitWarning("Expecting source and result to have same rank.");
    return;
  }
  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
  if (broadcastUnitDims.size() != 1) {
    broadcast.emitWarning("Expecting source type to be nD vector only with "
                          "one broadcasted dimension.");
    return;
  }
  // Propagate the result layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
void LayoutInfoPropagation::visitShapeCastOp(
    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  VectorType sourceTy = shapeCast.getSourceVectorType();
  VectorType resultTy = shapeCast.getResultVectorType();
  // Only 1D -> 2D shape casts are supported for now.
  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
    return;
  }
  // The source gets a slice of the 2D result layout along the unit dimension.
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
}
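// Worked example: for
//   %r = vector.shape_cast %s : vector<16xf32> to vector<1x16xf32>
// resultTy.getShape()[0] == 1, so slicedDim = 0 and the source gets a
// #xegpu.slice of the 2D result layout along dim 0, i.e. the distribution of
// the surviving 16-element dimension.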
/// Propagate the layout of the result tensor descriptor to the source tensor
/// descriptor of an UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Set the layouts required for the A, B, and accumulator operands of DPAS.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  propagateIfChanged(
      operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(
      operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(
        operands[2],
        operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
  }
}
/// Set the layout for the value and tensor descriptor operands of StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands)
    propagateIfChanged(operand, operand->meet(storeLayout));
}
/// Propagate the layout of the loaded value to the tensor descriptor operand
/// of LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // The layout of the loaded value is needed to propagate.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has a transpose effect; it is not expected at this stage, so
  // warn, but still transpose the layout before propagating.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.transpose(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
/// For vector::TransposeOp, the layout of the result is transposed according
/// to the permutation and propagated to the source operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the transpose result is needed to propagate.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
  // Propagate the transposed layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
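// Worked example: for
//   %r = vector.transpose %s, [1, 0] : vector<8x16xf32> to vector<16x8xf32>
// a result layout of lane_layout = [1, 16], lane_data = [1, 1] is permuted
// by [1, 0], so the source operand receives lane_layout = [16, 1],
// lane_data = [1, 1].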
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the bitcast result is needed to propagate.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
  // If the element bit widths are equal, propagate the layout unchanged.
  if (inElemTyBitWidth == outElemTyBitWidth) {
    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
    return;
  }
  // The result vector type must be evenly distributable across lanes.
  auto resultLaneLayout = resultLayout.getLaneLayout();
  auto resultLaneData = resultLayout.getLaneData();
  if (failed(xegpu::getDistributedVectorType(
          bitcast.getResultVectorType(),
          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
                                 resultLaneData)))) {
    bitcast.emitWarning(
        "Result vector type can not be evenly distributed across lanes.");
    return;
  }
  int64_t rank = bitcast.getSourceVectorType().getRank();
  // A bitcast is `narrowing` if the input element bit width is greater than
  // the output element bit width.
  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                 : outElemTyBitWidth / inElemTyBitWidth;
  SmallVector<int> sourceLaneLayout =
      resultLayout.getLaneLayout(); // Lane layout is unchanged by a bitcast.
  SmallVector<int> outData = resultLayout.getLaneData();
  // For a narrowing bitcast, each lane's innermost chunk of output bits must
  // cover at least one whole input element.
  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
  if (outInnerBitsPerLane < inElemTyBitWidth) {
    bitcast.emitWarning(
        "Narrowing bitcast with cross lane communication is not supported.");
    return;
  }
  // All dimensions other than the innermost must have lane data of 1.
  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
    bitcast.emitWarning("Each lane must not own multiple elements in any "
                        "dimension other than the innermost dimension.");
    return;
  }
  // The innermost lane data is scaled by the bitcast ratio.
  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                          : outData[rank - 1] * bitCastRatio;
  sourceLaneData.push_back(innerMostLaneData);

  propagateIfChanged(
      operands[0],
      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
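// Worked example (widening; a sketch): for
//   %r = vector.bitcast %s : vector<16x32xi8> to vector<16x8xi32>
// inElemTyBitWidth = 8 and outElemTyBitWidth = 32, so isNarrowing is false
// and bitCastRatio = 4; with result lane_data = [1, 1], the source keeps the
// same lane layout and gets innermost lane data 1 * 4 = 4: the four
// consecutive i8 elements owned by a lane reinterpret as one i32.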
/// Propagate layouts to the source, mask, and offset operands of LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout is strictly determined by the payload type.
  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
  if (!payloadTy) {
    load.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*isScattered=*/true);
  // The mask and offset operands get the default 1D layout.
  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1);
  // Propagate the payload layout to the tensor descriptor (if present).
  if (isa<xegpu::TensorDescType>(load.getSourceType()))
    propagateIfChanged(operands[0], operands[0]->meet(layout));
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
  if (load.getOffsets())
    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
/// Propagate the default 1D layout to the offset operand of CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // The layout of the descriptor is needed to propagate.
  if (!descLayout.isAssigned())
    return;
  // The offsets operand gets the default 1D layout.
  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}
/// Set the layouts for the value, tensor descriptor, mask, and offset
/// operands of StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
  if (!payloadTy) {
    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
    return;
  }
  auto payloadShape = payloadTy.getShape();
  if (payloadShape.size() > 1)
    assert(payloadShape[0] == xegpu::targetinfo::subgroupSize &&
           "Expected the first dimension of 2D tensor descriptor to be equal "
           "to subgroup size.");
  LayoutInfo payloadLayout =
      getDefaultSIMTLayoutInfo(payloadTy, /*isScattered=*/true);
  LayoutInfo maskLayout =
      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1);
  // Propagate the payload layout to the value operand.
  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
  // Propagate the payload layout to the tensor descriptor (if present).
  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
  // The mask and (optional) offset operands get the 1D layout.
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
  if (storeScatter.getOffsets())
    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
namespace {

//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Driver that runs the LayoutInfoPropagation analysis and provides access to
/// its results.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  Operation *target;
};
} // namespace

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Print the analysis result for the function arguments.
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Print the analysis result for each op result.
    funcOp.walk([&](Operation *op) {
      // Skip branch ops; they do not define results of interest.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        return;
      if (op->getResults().empty())
        return;
      os << "op    : ";
      op->print(os);
      os << "\n";
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);
    // Collect all function-like ops inside nested GPU modules as well.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
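// The printed report reads roughly like the following (a sketch; the exact
// layout text comes from LayoutInfo::print, and "my_kernel" is a placeholder
// function name):
//   function: my_kernel:
//   argument: <block argument> ...
//   layout  : Not assigned.
//   op    : %0 = xegpu.load_nd ...
//   layout for result #0: ...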
/// Function type for querying the layout of a value.
using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;

/// Update an operation with the layout of its results.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                              GetLayoutFnTy getLayoutOfValue) {
  // Region ops (like scf.for) are handled by updateControlFlowOps.
  if (mlir::isa<mlir::RegionBranchOpInterface>(op))
    return success();
  // Iterate over all the results of the op.
  for (OpResult result : op->getResults()) {
    Type resultType = result.getType();
    // Layouts are only attached to vector and tensor descriptor results.
    if (!isa<VectorType, xegpu::TensorDescType>(resultType))
      continue;
    // If the result has users but no layout, warn and move on.
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
    if (!layout && result.getNumUses() > 0) {
      op->emitWarning("op has users but no layout assigned for its result");
      continue;
    }
    // For tensor descriptor results, the layout is embedded into the type.
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
      auto typeWithLayout = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      result.setType(typeWithLayout);
      continue;
    }
    // Vector results keep their type and carry the layout as an attribute.
    xegpu::setDistributeLayoutAttr(result, layout);
  }
  return success();
}
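// Example effect (a sketch of the attribute syntax): a result of type
//   !xegpu.tensor_desc<16x16xf16>
// whose propagated layout is lane_layout = [1, 16], lane_data = [1, 2] is
// rewritten in place to
//   !xegpu.tensor_desc<16x16xf16,
//       #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
// while vector-typed results keep their type and carry the layout via
// xegpu::setDistributeLayoutAttr.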
/// Region ops like scf.for need special handling because they have blocks
/// inside: layouts assigned to forwarded operands must be reflected on the
/// corresponding region successor inputs.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
                     mlir::RegionBranchTerminatorOpInterface terminator,
                     GetLayoutFnTy getLayoutOfValue) {
  // Only terminators inside region branch ops are of interest.
  if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
    return success();
  llvm::SmallVector<mlir::RegionSuccessor> successors;
  llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
                                              nullptr);
  terminator.getSuccessorRegions(operands, successors);
  for (mlir::RegionSuccessor &successor : successors) {
    mlir::ValueRange successorInputs = successor.getSuccessorInputs();
    mlir::OperandRange successorOperands =
        terminator.getSuccessorOperands(successor);
    for (auto [successorOperand, successorInput] :
         llvm::zip(successorOperands, successorInputs)) {
      Type inputType = successorInput.getType();
      // Only vector and tensor descriptor types are of interest.
      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
        continue;
      xegpu::DistributeLayoutAttr successorInputLayout =
          getLayoutOfValue(successorInput);
      xegpu::DistributeLayoutAttr successorOperandLayout =
          getLayoutOfValue(successorOperand);
      // If the forwarded operand has no layout, there is nothing to do.
      if (!successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
                             "branch terminator: "
                          << successorOperand << "\n");
        continue;
      }
      // The branch operand and its matching input should agree on the layout.
      if (successorInputLayout &&
          successorInputLayout != successorOperandLayout) {
        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
                             "operand forwarded as the argument: "
                          << successorInputLayout << " vs "
                          << successorOperandLayout << "\n");
        continue;
      }
      // For tensor descriptor inputs, embed the layout into the type.
      if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
        auto newTdescTy = xegpu::TensorDescType::get(
            tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
            tdescTy.getEncoding(), successorOperandLayout);
        successorInput.setType(newTdescTy);
        continue;
      }
      // If the vector-typed successor input is an op result, attach the
      // layout as an attribute on the defining op.
      if (auto result = dyn_cast<OpResult>(successorInput))
        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
    }
  }
  return success();
}
/// Update the function arguments and results with the layouts.
static LogicalResult
updateFunctionOpInterface(mlir::OpBuilder &builder,
                          mlir::FunctionOpInterface funcOp,
                          GetLayoutFnTy getLayoutOfValue) {
  SmallVector<Type> newArgTypes;
  // Update the function argument types with the layouts.
  for (BlockArgument arg : funcOp.getArguments()) {
    Type argType = arg.getType();
    newArgTypes.push_back(argType);
    if (!isa<VectorType, xegpu::TensorDescType>(argType))
      continue;
    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
    if (!layout) {
      LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                        << " but got none.\n");
      continue;
    }
    if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
      auto newTdescTy = xegpu::TensorDescType::get(
          tensorDescTy.getContext(), tensorDescTy.getShape(),
          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
      arg.setType(newTdescTy);
      newArgTypes.back() = newTdescTy;
    }
  }
  // Update the function type with the new argument types.
  funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
                                   funcOp.getResultTypes()));
  return success();
}
namespace {

//===----------------------------------------------------------------------===//
// XeGPUPropagateLayoutPass
//===----------------------------------------------------------------------===//

struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};
} // namespace
void XeGPUPropagateLayoutPass::runOnOperation() {
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // Print the analysis result and exit early, if requested.
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
  // Helper to convert the analysis result to an xegpu::DistributeLayoutAttr.
  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
    LayoutInfo layout = analysis.getLayoutInfo(val);
    if (!layout.isAssigned())
      return {};
    if (layout.isSliceLayout())
      return cast<xegpu::SliceAttr>(layout.get());
    return cast<xegpu::LayoutAttr>(layout.get());
  };

  mlir::OpBuilder builder(&getContext());
  Operation *op = getOperation();
  auto walkResult = op->walk([&](Operation *op) -> WalkResult {
    LogicalResult r = success();
    TypeSwitch<Operation *>(op)
        .Case<mlir::RegionBranchTerminatorOpInterface>(
            [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
              r = updateControlFlowOps(builder, branchTermOp,
                                       getXeGPULayoutForValue);
            })
        .Case<mlir::FunctionOpInterface>(
            [&](mlir::FunctionOpInterface funcOp) {
              r = updateFunctionOpInterface(builder, funcOp,
                                            getXeGPULayoutForValue);
            })
        .Default([&](Operation *op) {
          r = updateOp(builder, op, getXeGPULayoutForValue);
        });
    if (failed(r)) {
      op->emitError("Failed to update operation with the layout.");
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }
}
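// Typical usage (a sketch; the pass name matches DEBUG_TYPE above, and the
// print-only option name follows the pass's tablegen definition):
//   mlir-opt --xegpu-propagate-layout input.mlir
// Running with the print-only option reports the inferred layouts without
// rewriting any types or attributes.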