#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"
static bool isSharedMemory(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == 3;
  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
    return memrefSpace.getValue() == MemorySpace::SLM;
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}
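// Illustrative examples (attribute syntax approximate, for orientation only):
// all three encodings below are recognized as shared (workgroup-local) memory
// by isSharedMemory:
//   memref<256xf16, 3>                          -> IntegerAttr 3
//   memref<256xf16, #xegpu.memory_space<slm>>   -> MemorySpaceAttr SLM
//   memref<256xf16, #xevm.addr_space<shared>>   -> xevm::AddrSpaceAttr SHARED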
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}
static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING ||
         kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,
                           function_ref<InFlightDiagnostic()> emitError) {
  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  auto chunkSize = tdescTy.getChunkSizeAsInt();

  // Scalar result: the chunk size must be 1 and the mask must be scalar too.
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (dyn_cast<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  // The mask shape is the TensorDesc shape without the trailing chunk dim.
  SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // A valid SIMT distribution: a 1-D value holding one chunk per lane.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                 VectorType valueTy, int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {
  auto maskVecTy = dyn_cast<VectorType>(maskTy);
  auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);

  // Scalar result: expect chunk size == 1 and scalar mask/offsets; when both
  // mask and offsets are vectors, a vector result is expected instead.
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (maskVecTy && offsetsVecTy)
      return emitError() << "Expecting a vector type result.";
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
    return success();
  }

  auto valueSize = valueTy.getNumElements();

  // Scalar mask and offsets with a vector value: the value is a single chunk.
  if (!maskVecTy && !offsetsVecTy) {
    if (valueSize != chunkSize)
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
    return success();
  }

  if (!maskVecTy)
    return emitError() << "Expecting a vector type mask.";
  int64_t maskSize = maskVecTy.getNumElements();

  // SIMT mode: a 1-D value carries exactly one chunk per lane.
  if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
    return emitError() << "value elements must match chunk size "
                       << chunkSize;

  // With chunk size 1, the mask and value have the same element count.
  if (chunkSize == 1 && valueSize != maskSize)
    return emitError()
           << "Mask should match value except the chunk size dim.";

  // Otherwise the value shape minus its trailing chunk dim must equal the
  // mask shape.
  SmallVector<int64_t> expectedMaskShape(getShapeOf(valueTy));
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  auto maskShape = getShapeOf(maskTy);
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match value except the chunk size dim.";

  return success();
}
static LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
                      function_ref<InFlightDiagnostic()> emitError) {
  if (!dataTy) {
    if (subgroup_block_io)
      return emitError() << "subgroup_block_io "
                            "are only allowed when result is a VectorType.";
    return success();
  }

  if (mdescTy.getRank() != 2)
    return emitError() << "mem_desc must be 2D.";

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
  ArrayAttr strideAttr = mdescTy.getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue())
    strides.push_back(cast<IntegerAttr>(attr).getInt());

  if (subgroup_block_io && layout) {
    auto laneData = layout.getEffectiveLaneDataAsInt();
    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
    if (!laneData.empty()) {
      bool isLaneDataContiguous =
          std::all_of(laneData.begin(), std::prev(laneData.end()),
                      [](int x) { return x == 1; });
      if (!isLaneDataContiguous)
        return emitError() << "With subgroup_block_io, accessed data must be "
                              "contiguous and coalesced.";
      for (size_t i = 0; i < laneData.size(); ++i) {
        if (laneLayout[i] != blockShape[i])
          return emitError() << "With subgroup_block_io, the block shape must "
                                "match the lane layout.";
        if (laneLayout[i] != 1 && strides[i] != 1)
          return emitError() << "With subgroup_block_io, the distributed "
                                "dimensions must be contiguous.";
      }
    }
  }

  if (dataShape.size() == 2) {
    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
      return emitError() << "data shape must not exceed mem_desc shape.";
  }

  if (subgroup_block_io && !blockShape.size())
    return emitError() << "mem_desc must have block attribute when "
                          "subgroup_block_io is set.";

  if (subgroup_block_io && mdescTy.isColMajor())
    return emitError() << "mem_desc should be row major when "
                          "subgroup_block_io is set.";

  return success();
}
// Builder for the no-offset form on a statically shaped memref source.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");

  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
        ValueRange({}) /* dynamic shape */,
        ValueRange({}) /* dynamic strides */,
        DenseI64ArrayAttr() /* static offsets */,
        DenseI64ArrayAttr() /* static shape */,
        DenseI64ArrayAttr() /* static strides */);
}
// Builder for the form with explicit shape and strides; the source may be a
// memref or a plain integer pointer.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           ArrayRef<OpFoldResult> shape,
                           ArrayRef<OpFoldResult> strides) {
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  SmallVector<Value> dynamicShape, dynamicStrides;
  SmallVector<int64_t> staticShape, staticStrides;
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If shape and strides merely restate the memref type, drop the
    // attributes to keep the printed IR clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, DenseI64ArrayAttr(), staticShapeAttr,
        staticStridesAttr);
}
// Builder for the form with explicit offsets on a static-shape memref.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets,
        ValueRange({}) /* dynamic shape */,
        ValueRange({}) /* dynamic strides */, staticOffsetsAttr,
        DenseI64ArrayAttr() /* static shape */,
        DenseI64ArrayAttr() /* static strides */);
}
// Builder for the fully dynamic form: offsets, shape, and strides.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> shape,
                           ArrayRef<OpFoldResult> strides) {
  assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  SmallVector<Value> dynamicOffsets, dynamicShape, dynamicStrides;
  SmallVector<int64_t> staticOffsets, staticShape, staticStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // Same cleanup as above: drop attributes that restate the memref type.
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
LogicalResult CreateNdDescOp::verify() {
  size_t rank = getMixedSizes().size();
  bool invalidRank = rank != getMixedStrides().size();
  bool invalidElemTy = false;

  // Memory space of the TensorDesc should match the source. Both default to
  // global memory when no memory-space attribute is given; an integer source
  // is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  if (size_t offsetRank = getMixedOffsets().size())
    invalidRank |= (offsetRank != rank);

  // If the source is a memref, its rank must match and its element type must
  // agree with the TensorDesc.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType())) {
    invalidRank |= (memrefTy.getRank() != (int64_t)rank);
    invalidElemTy |= memrefTy.getElementType() != getType().getElementType();
  }

  if (llvm::isa<IntegerType>(getSourceType())) {
    // Strides and shape must be present for an integer (pointer) source.
    if (getMixedStrides().empty() || getMixedSizes().empty())
      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }

  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  if (getType().getRank() > (int64_t)rank)
    return emitOpError(
        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
static ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
  SmallVector<int64_t> integerVals;
  auto parseIntegerOrValue = [&]() {
    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);
    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (parser.parseInteger(integer))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // The whole list is optional: no opening bracket means no indices.
  if (failed(parser.parseOptionalLSquare()))
    return success();
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None,
                                     parseIntegerOrValue) ||
      parser.parseRSquare())
    return parser.emitError(parser.getNameLoc())
           << "expected a list of SSA values or integers";
  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
  return success();
}

static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                          OperandRange values,
                                          DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;
  printDynamicIndexList(printer, op, values, integers.asArrayRef(),
                        /*scalableFlags=*/{}, /*valueTypes=*/{},
                        AsmParser::Delimiter::Square);
}
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
               l1_hint, l2_hint, l3_hint);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
        l2_hint, l3_hint);
}
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, UnitAttr packed,
                     DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, retType, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
               l3_hint);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                     UnitAttr packed, DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        packed, transpose, l1_hint, l2_hint, l3_hint);
}
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the TensorDesc,
  // the op is in SIMT mode: each lane holds tdescElems / valueElems elements.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need a layout attribute on the TensorDesc.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;
    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Apply the permutation only when it is valid for the TensorDesc rank.
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, value, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        l1_hint, l2_hint, l3_hint);
}
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType();
  auto valTy = getValueType();

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // A 1D value with fewer elements than the TensorDesc indicates SIMT mode:
  // each lane stores tdescElems / valueElems elements.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need a layout attribute on the TensorDesc.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets must match the TensorDesc rank.
  if (ty.getRank() != (int64_t)getNumOffsets())
    return emitOpError("Invalid number of offsets.");

  return success();
}
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Memory space of the TensorDesc should match the source.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // The expected TensorDesc shape is the offsets shape, extended by the
  // chunk size when it is greater than one.
  auto chunkSize = tdescTy.getChunkSizeAsInt();
  SmallVector<int64_t> shape(getShapeOf(getOffsets().getType()));
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
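// Shape sketch (illustrative values only): with offsets vector<16xindex> and
// chunk size 8, the expected descriptor shape is [16, 8]; with chunk size 1
// it is just [16], i.e. one element per offset.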
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();

  // Source is either a scattered TensorDesc or a raw pointer/memref with
  // explicit offsets; exactly one of the two forms must be used.
  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");
  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto srcTy = getSourceType();
  if (srcTy.isInteger() && !getOffsetAlignByteAttr())
    return emitOpError("offset_align_byte is required with integer source.");

  if (getOffsetAlignByteAttr() && !srcTy.isInteger())
    return emitOpError("offset_align_byte only allowed with integer source.");

  return success();
}
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
                       xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
        /*offset_align_byte=*/IntegerAttr());
}
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = dyn_cast<VectorType>(getValue().getType());

  // Either a scattered TensorDesc source or explicit offsets, but not both.
  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");
  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  auto srcTy = getSourceType();
  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(srcTy);

  if (memTy && valueTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint,
                         xegpu::LayoutAttr layout) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, layout);
}
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = dyn_cast<VectorType>(getValue().getType());

  // Either a scattered TensorDesc destination or explicit offsets, not both.
  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");
  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  auto destTy = getDestType();
  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(destTy);

  if (memTy && valueTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint,
                           xegpu::LayoutAttr layout) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, layout);
}
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}
LogicalResult UpdateOffsetOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
  SmallVector<int64_t> offsetShape = getShapeOf(getOffsets().getType());
  if (tdescTy.getChunkSizeAsInt() > 1)
    expectedOffsetShape.pop_back();

  if (expectedOffsetShape != offsetShape)
    return emitOpError(
        "Offsets should match TensorDesc except the chunk size dim.");

  return success();
}
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand must be a multiple of 32 bits.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code.
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}
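// Dimension sketch for the SIMD form (illustrative shapes only): lhs is MxK,
// rhs is KxN (or K/vnni x N x vnni when packed 3-D), result is MxN, e.g.
//   lhs vector<8x16xf16>, rhs vector<16x16xf16> -> result vector<8x16xf32>
// and for packed rhs vector<8x16x2xf16>, bK = 8 * 2 = 16 must equal
// lhsShape[1].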
LogicalResult ConvertLayoutOp::verify() {
  auto srcLayout = getInputLayout();
  auto resLayout = getTargetLayout();

  // Input and target layouts must both be workgroup-level or both be
  // subgroup-level.
  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input layout and target layout be WgLayout or "
                       "SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
    return emitOpError(
        "invalid input layout, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
    return emitOpError(
        "invalid target layout, data cannot be evenly distributed.");

  return mlir::success();
}
OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
  // Converting to the same layout is a no-op.
  if (getInputLayout() == getTargetLayout())
    return getSource();
  return {};
}

struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInputLayout() == op.getTargetLayout()) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }
    return failure();
  }
};
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MemDescType> memDesc,
                         llvm::ArrayRef<OpFoldResult> offsets,
                         DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}

LogicalResult LoadMatrixOp::verify() {
  auto resTy = dyn_cast<VectorType>(getRes().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state,
                          Value data, TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
                          DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}

LogicalResult StoreMatrixOp::verify() {
  auto dataTy = dyn_cast<VectorType>(getData().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static Type getValueType(Attribute attr)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< int64_t > getShapeOf(Type type)
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
static std::string makeString(T array, bool breakline=false)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)
static bool isSharedMemory(const MemRefType &memrefTy)
static ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, function_ref< InFlightDiagnostic()> emitError)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::function_ref< Fn > function_ref
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.