#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {
static bool isSharedMemory(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  // No memory-space attribute means the default (global) space.
  if (!attr)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == 3;
  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
    return memrefSpace.getValue() == MemorySpace::SLM;
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}
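// Descriptive note: the accepted encodings above are equivalent ways of
// saying "shared local memory": the raw integer address space 3, the XeGPU
// MemorySpace::SLM attribute, and the XeVM AddrSpace::SHARED attribute;
// anything else falls back to the GPU dialect's notion of workgroup memory.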
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  // Guard against an empty array, for which back() would be undefined.
  if (!array.empty())
    os << array.back();
  os << "]";
  return buf;
}
static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1); // non-shaped (scalar) types are treated as shape [1]
  return shape;
}
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}
static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
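// Together these encode the legal cache-hint sets: loads/prefetches may use
// CACHED, UNCACHED, STREAMING, or READ_INVALIDATE, while stores may use
// CACHED, UNCACHED, WRITE_BACK, or WRITE_THROUGH; a null attribute (no hint)
// is always accepted.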
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  auto chunkSize = tdescTy.getChunkSizeAsInt();

  // Scalar result: only chunk size 1 with a scalar mask is valid.
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (dyn_cast<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  // The mask covers the tdesc shape minus the trailing chunk dimension.
  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // A valid SIMT distribution: each lane owns exactly one chunk.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                 VectorType valueTy, int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  auto maskVecTy = dyn_cast<VectorType>(maskTy);
  auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);

  // Scalar result: only chunk size 1 with scalar mask and offsets is valid.
  // (The both-vectors case is checked first so that each branch is
  // reachable; previously it sat dead behind the "or" check.)
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (maskVecTy && offsetsVecTy)
      return emitError() << "Expecting a vector type result.";
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
    return success();
  }

  auto valueSize = valueTy.getNumElements();

  // Scalar mask and offsets: the value holds exactly one chunk.
  if (!maskVecTy && !offsetsVecTy) {
    if (valueSize != chunkSize)
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
    return success();
  }

  // Otherwise a vector mask is required.
  if (!maskVecTy)
    return emitError() << "Expecting a vector type mask.";
  int64_t maskSize = maskVecTy.getNumElements();

  // SIMT case: a 1D value must carry exactly chunkSize elements per lane.
  if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
    return emitError() << "value elements must match chunk size "
                       << chunkSize;

  // With chunk size 1, value and mask sizes must agree directly.
  if (chunkSize == 1 && valueSize != maskSize)
    return emitError()
           << "Mask should match value except the chunk size dim.";

  // With a larger chunk size, the mask must match the value shape with the
  // trailing chunk dimension dropped.
  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match value except the chunk size dim.";

  return success();
}
static LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
                      function_ref<InFlightDiagnostic()> emitError) {

  if (!dataTy) {
    if (subgroup_block_io)
      return emitError() << "subgroup_block_io is only allowed when the "
                            "result is a VectorType.";
    return success();
  }

  if (mdescTy.getRank() != 2)
    return emitError() << "mem_desc must be 2D.";

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  SmallVector<int64_t> blockShape = mdescTy.getBlockShape();

  ArrayAttr strideAttr = mdescTy.getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue())
    strides.push_back(cast<IntegerAttr>(attr).getInt());

  if (subgroup_block_io && layout) {
    auto laneData = layout.getEffectiveLaneDataAsInt();
    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
    if (!laneData.empty()) {
      // All but the innermost laneData entries must be 1.
      bool isLaneDataContiguous =
          std::all_of(laneData.begin(), std::prev(laneData.end()),
                      [](int x) { return x == 1; });
      if (!isLaneDataContiguous)
        return emitError() << "With subgroup_block_io, accessed data must be "
                              "contiguous and coalesced.";
      for (size_t i = 0; i < laneData.size(); ++i) {
        if (laneLayout[i] != blockShape[i])
          return emitError() << "With subgroup_block_io, the block shape must "
                                "match the lane layout.";
        if (laneLayout[i] != 1 && strides[i] != 1)
          return emitError() << "With subgroup_block_io, the distributed "
                                "dimensions must be contiguous.";
      }
    }
  }

  if (dataShape.size() == 2) {
    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
      return emitError() << "data shape must not exceed mem_desc shape.";
  }

  if (subgroup_block_io && blockShape.empty())
    return emitError() << "mem_desc must have block attribute when "
                          "subgroup_block_io is set.";
  if (subgroup_block_io && mdescTy.isColMajor())
    return emitError() << "mem_desc should be row major when "
                          "subgroup_block_io is set.";

  return success();
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");

  build(builder, state, tdesc, source, ValueRange({}) /*offsets*/,
        ValueRange({}) /*shape*/, ValueRange({}) /*strides*/,
        DenseI64ArrayAttr(), DenseI64ArrayAttr(), DenseI64ArrayAttr());
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           ArrayRef<OpFoldResult> shape,
                           ArrayRef<OpFoldResult> strides) {
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  SmallVector<Value> dynamicShape, dynamicStrides;
  SmallVector<int64_t> staticShape, staticStrides;
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  // Drop shape/strides that merely restate a static memref source type.
  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, DenseI64ArrayAttr(), staticShapeAttr,
        staticStridesAttr);
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets,
        ValueRange({}) /*shape*/, ValueRange({}) /*strides*/,
        builder.getDenseI64ArrayAttr(staticOffsets), DenseI64ArrayAttr(),
        DenseI64ArrayAttr());
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> shape,
                           ArrayRef<OpFoldResult> strides) {
  assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  SmallVector<Value> dynamicOffsets, dynamicShape, dynamicStrides;
  SmallVector<int64_t> staticOffsets, staticShape, staticStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  // As above, drop shape/strides that restate a static memref source type.
  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
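// Usage sketch (hypothetical caller; assumes an OpBuilder `b`, a Location
// `loc`, a TensorDescType `tdescTy`, and a statically shaped memref value
// `src`):
//   auto desc = CreateNdDescOp::create(b, loc, tdescTy, src);
// The convenience overloads above only differ in which of the offset /
// shape / stride operands they synthesize for the generated builder.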
LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedSizes().size();
  bool invalidRank = rank != (int64_t)getMixedStrides().size();
  bool invalidElemTy = false;

  // Memory space of the source must match the TensorDesc. Both default to
  // global memory when no memory-space attribute is given; an integer
  // source is treated as a pointer into global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  if (size_t offsetRank = getMixedOffsets().size())
    invalidRank |= ((int64_t)offsetRank != rank);

  // A memref source must agree in rank and element type.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType())) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getType().getElementType();
  }

  if (llvm::isa<IntegerType>(getSourceType())) {
    // Strides and shape must be present for an integer (pointer) source.
    if (getMixedStrides().empty() || getMixedSizes().empty())
      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }

  if (invalidRank)
    return emitOpError(
        "Expecting the ranks of shape, strides, offsets, and source (if "
        "source is a memref) to match with each other.");

  if (getType().getRank() > rank)
    return emitOpError(
        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.");

  return success();
}
static ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {

  SmallVector<int64_t, 4> integerVals;
  auto parseIntegerOrValue = [&]() {
    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);

    if (res.has_value() && succeeded(res.value())) {
      // An SSA value: record a dynamic placeholder in the integer array.
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      // Otherwise expect an integer literal.
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // The list is optional; no opening bracket means no offsets at all.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None,
                                     parseIntegerOrValue,
                                     " in optional dynamic index list") ||
      parser.parseRSquare())
    return parser.emitError(parser.getNameLoc())
           << "expected a list of SSA values or integers";

  integers = DenseI64ArrayAttr::get(parser.getContext(), integerVals);
  return success();
}

static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                          OperandRange values,
                                          DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;
  printDynamicIndexList(printer, op, values, integers.asArrayRef(),
                        /*scalableFlags=*/{}, TypeRange(),
                        AsmParser::Delimiter::Square);
}
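// These two functions back the ops' custom assembly directive for optional
// offsets: the printed form is a square-bracketed, comma-separated mix of
// SSA values and integer literals, e.g. `[%off0, 16]`, and the whole list
// is omitted when no offsets are present.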
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
               l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getMixedOffsets().size());
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, UnitAttr packed,
                     DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, retType, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
               l3_hint, /*layout=*/nullptr);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                     UnitAttr packed, DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        packed, transpose, l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  VectorType valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // A 1D result smaller than the descriptor is a SIMT per-lane fragment;
  // it needs no layout attribute, but must evenly divide the descriptor.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;
    return success();
  }

  // SIMD mode: simulate the effect of transpose/packed/array_length on the
  // descriptor shape, then compare against the result shape.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Apply the permutation only when it is valid for the tdesc rank.
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getMixedOffsets().size());
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, value, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
               /*layout=*/nullptr);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // destination tile
  auto valTy = getValueType();      // stored vector

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.");

  if (!valTy)
    return emitOpError("Expecting a VectorType value.");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // A 1D value smaller than the descriptor is a SIMT per-lane fragment.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
    return success();
  }

  // SIMD mode: shapes must match exactly.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getMixedOffsets().size());
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.");

  // The number of offsets must match the rank of the tensor descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets())
    return emitOpError("Invalid number of offsets.");

  return success();
}
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  // Memory space of the source must match the TensorDesc.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // The expected tdesc shape is the offsets shape, with the chunk size
  // appended as the innermost dimension when it is larger than 1.
  auto chunkSize = tdescTy.getChunkSizeAsInt();
  SmallVector<int64_t> shape(getShapeOf(getOffsets().getType()));
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape);

  return success();
}
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();

  // Exactly one of {TensorDesc source, raw source + offsets} must be used.
  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");
  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto srcTy = getSourceType();
  if (srcTy.isInteger() && !getOffsetAlignByteAttr())
    return emitOpError("offset_align_byte is required with integer source.");

  if (getOffsetAlignByteAttr() && !srcTy.isInteger())
    return emitOpError("offset_align_byte only allowed with integer source.");

  return success();
}
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
                       xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
        IntegerAttr{}, /*layout=*/nullptr);
}
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");
  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  // TensorDesc form: delegate to the scatter/gather checker.
  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  // Raw source form: memref (or integer pointer) plus offsets.
  auto srcTy = getSourceType();
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(srcTy);

  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}
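// The tensor-descriptor form is checked by isValidGatherScatterParams; the
// raw source + offsets form is checked by isValidGatherScatterBufferParams,
// with the optional chunk_size defaulting to 1.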
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint,
                         DistributeLayoutAttr layout) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, layout);
}
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");
  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  // TensorDesc form: delegate to the scatter/gather checker.
  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  // Raw destination form: memref (or integer pointer) plus offsets.
  auto destTy = getDestType();
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(destTy);

  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
        l3_hint, /*layout=*/nullptr);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint,
                           DistributeLayoutAttr layout) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
        l3_hint, layout);
}
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}
LogicalResult UpdateOffsetOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
  SmallVector<int64_t> offsetShape = getShapeOf(getOffsets().getType());
  if (tdescTy.getChunkSizeAsInt() > 1)
    expectedOffsetShape.pop_back();

  if (expectedOffsetShape != offsetShape)
    return emitOpError(
        "Offsets should match TensorDesc except the chunk size dim.");

  return success();
}
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT case: all operands are 1D vectors; only B's 32-bit packing is
  // checked here, since per-lane shapes depend on the distribution.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD case.
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}
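// Shape sanity example: lhs vector<8x16xf16> times rhs vector<16x16xf16>
// (or packed vector<8x16x2xf16>, where K = 8 * 2) accumulates into
// vector<8x16xf32>; M, N, and K must agree as checked above.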
LogicalResult ConvertLayoutOp::verify() {
  auto srcLayout = getInputLayout();
  auto resLayout = getTargetLayout();

  // Both layouts must describe the same level: either both workgroup-level
  // or both subgroup-level.
  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input and target layouts to both be "
                       "WgLayout or SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
    return emitOpError(
        "invalid input layout, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
    return emitOpError(
        "invalid target layout, data cannot be evenly distributed.");

  return mlir::success();
}
OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
  // A conversion between identical layouts is a no-op.
  if (getInputLayout() == getTargetLayout())
    return getSource();
  return {};
}

namespace {
// Canonicalization pattern mirroring the folder above.
struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInputLayout() == op.getTargetLayout()) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }
    return failure();
  }
};
} // namespace

void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<FoldConvertLayoutOp>(context);
}
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MemDescType> memDesc,
                         ArrayRef<OpFoldResult> offsets,
                         DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}

LogicalResult LoadMatrixOp::verify() {
  auto resTy = dyn_cast<VectorType>(getRes().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state,
                          Value data, TypedValue<MemDescType> memDesc,
                          ArrayRef<OpFoldResult> offsets,
                          DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}

LogicalResult StoreMatrixOp::verify() {
  auto dataTy = dyn_cast<VectorType>(getData().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}
} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static Type getValueType(Attribute attr)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< int64_t > getShapeOf(Type type)
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
static std::string makeString(T array, bool breakline=false)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)
static bool isSharedMemory(const MemRefType &memrefTy)
static ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, function_ref< InFlightDiagnostic()> emitError)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::function_ref< Fn > function_ref
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.