#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"
// Returns true if the given memref is allocated in shared (workgroup-local)
// memory, under any of the spellings the dialect accepts.
static bool isSharedMemory(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == 3;
  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
    return memrefSpace.getValue() == MemorySpace::SLM;
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}
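// Illustrative examples (not from the original source): under the checks
// above, `memref<8x16xf16, 3>` (integer address space 3) and
// `memref<8x16xf16, #gpu.address_space<workgroup>>` are both classified as
// shared memory, while a memref with no memory-space attribute is not.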
// Prints an integer array as "[a, b, c]", optionally breaking the line
// between elements. Used to embed shapes in diagnostics.
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  if (array.empty())
    return "[]";
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n";
  }
  os << array.back() << "]";
  return buf;
}
// Returns the shape of the given type as a vector; non-shaped (scalar) types
// get the shape {1}.
static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}
// A cache hint is valid for loads/prefetches if it is absent or one of the
// read-oriented policies.
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

// A cache hint is valid for stores if it is absent or one of the
// write-oriented policies.
static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
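// Illustrative contrast between the two predicates: STREAMING and
// READ_INVALIDATE only make sense on the load path, WRITE_BACK and
// WRITE_THROUGH only on the store path, so for example
//   isReadHintOrNone(#xegpu.cache_hint<write_back>)  -> false
//   isWriteHintOrNone(#xegpu.cache_hint<write_back>) -> true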
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  auto chunkSize = tdescTy.getChunkSizeAsInt();

  // Scalar result: only a unit chunk and a scalar mask make sense.
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (dyn_cast<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  // The mask shape is the descriptor shape with the chunk dimension dropped.
  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // A 1D value holding exactly one chunk is a valid SIMT distribution.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                 VectorType valueTy, int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  auto maskVecTy = dyn_cast<VectorType>(maskTy);
  auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);

  // Scalar result: expect a unit chunk and scalar mask/offsets.
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (maskVecTy && offsetsVecTy)
      return emitError() << "Expecting a vector type result.";
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  int64_t valueSize = valueTy.getNumElements();

  // Scalar mask and offsets (SIMT): the value holds exactly one chunk.
  if (!maskVecTy && !offsetsVecTy) {
    if (valueSize != chunkSize)
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
    return success();
  }

  if (!maskVecTy)
    return emitError() << "Expecting a vector type mask.";
  int64_t maskSize = maskVecTy.getNumElements();

  // A 1D value must hold exactly one chunk per lane.
  if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
    return emitError() << "value elements must match chunk size "
                       << chunkSize;

  if (chunkSize == 1 && valueSize != maskSize)
    return emitError()
           << "Mask should match value except the chunk size dim.";

  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match value except the chunk size dim.";

  return success();
}
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
                      function_ref<InFlightDiagnostic()> emitError) {

  // A scalar result is allowed only without subgroup_block_io.
  if (!dataTy) {
    if (subgroup_block_io)
      return emitError() << "subgroup_block_io is only allowed when the "
                            "result is a VectorType.";
    return success();
  }

  if (mdescTy.getRank() != 2)
    return emitError() << "mem_desc must be 2D.";

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  SmallVector<int64_t> blockShape = mdescTy.getBlockShape();

  ArrayAttr strideAttr = mdescTy.getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  if (subgroup_block_io && layout) {
    auto laneData = layout.getEffectiveLaneDataAsInt();
    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
    if (!laneData.empty()) {
      bool isLaneDataContiguous =
          std::all_of(laneData.begin(), std::prev(laneData.end()),
                      [](int x) { return x == 1; });
      if (!isLaneDataContiguous)
        return emitError() << "With subgroup_block_io, accessed data must be "
                              "contiguous and coalesced.";
      for (size_t i = 0; i < laneData.size(); ++i) {
        if (laneLayout[i] != blockShape[i])
          return emitError() << "With subgroup_block_io, the block shape must "
                                "match the lane layout.";
        if (laneLayout[i] != 1 && strides[i] != 1)
          return emitError() << "With subgroup_block_io, the distributed "
                                "dimensions must be contiguous.";
      }
    }
  }

  if (dataShape.size() == 2) {
    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
      return emitError() << "data shape must not exceed mem_desc shape.";
  }

  if (subgroup_block_io && !blockShape.size())
    return emitError() << "mem_desc must have block attribute when "
                          "subgroup_block_io is set.";

  if (subgroup_block_io && mdescTy.isColMajor())
    return emitError() << "mem_desc should be row major when "
                          "subgroup_block_io is set.";

  return success();
}
// Builder for a statically shaped memref source with no explicit offsets,
// shape, or strides; everything is recovered from the memref type.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");

  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
        ValueRange({}) /* dynamic shape */,
        ValueRange({}) /* dynamic strides */,
        DenseI64ArrayAttr() /* static offsets */,
        DenseI64ArrayAttr() /* static shape */,
        DenseI64ArrayAttr() /* static strides */);
}

// Builder for an integer or memref source with explicit shape and strides
// and no offsets.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicShape, dynamicStrides;
  llvm::SmallVector<int64_t> staticShape, staticStrides;
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  // Drop shape and strides that merely restate a static memref type.
  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, DenseI64ArrayAttr() /* static offsets */,
        staticShapeAttr, staticStridesAttr);
}

// Builder for a statically shaped memref source with explicit offsets.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets,
        ValueRange({}) /* dynamic shape */,
        ValueRange({}) /* dynamic strides */,
        builder.getDenseI64ArrayAttr(staticOffsets),
        DenseI64ArrayAttr() /* static shape */,
        DenseI64ArrayAttr() /* static strides */);
}

// Builder for the fully general form: offsets, shape, and strides.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets, dynamicShape, dynamicStrides;
  llvm::SmallVector<int64_t> staticOffsets, staticShape, staticStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  // Drop shape and strides that merely restate a static memref type.
  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
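// Illustrative IR (assumed syntax, not from the original source) produced
// through these builders for a static 2D memref source:
//   %0 = xegpu.create_nd_tdesc %src[0, 0]
//        : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
// Shape and strides operands are omitted here because the canonicalization in
// the builders above recovers them from the static memref type.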
LogicalResult CreateNdDescOp::verify() {
  size_t rank = getMixedShape().size();
  bool invalidRank = rank != getMixedStrides().size();
  bool invalidElemTy = false;

  // The memory space of the TensorDesc must match the source. Both default
  // to global memory when no memory-space attribute is given; an integer
  // source is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  if (size_t offsetRank = getMixedOffsets().size())
    invalidRank |= (offsetRank != rank);

  // A memref source must match the rank and element type of the descriptor.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType())) {
    invalidRank |= (memrefTy.getRank() != (int64_t)rank);
    invalidElemTy |= memrefTy.getElementType() != getType().getElementType();
  }

  if (llvm::isa<IntegerType>(getSourceType())) {
    // An integer source is an opaque pointer, so shape and strides must be
    // given explicitly.
    if (getMixedShape().empty() || getMixedStrides().empty())
      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }

  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  if (getType().getRank() > (int64_t)rank)
    return emitOpError(
        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
// Parses an optional square-bracketed list mixing SSA values and integer
// literals. SSA values are collected in `values` and stand in as
// ShapedType::kDynamic in `integers`.
static ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {

  SmallVector<int64_t, 4> integerVals;
  auto parseIntegerOrValue = [&]() {
    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);

    // The element is an SSA value.
    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      // Otherwise it must be an integer literal.
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // The list is optional; bail out quietly if there is no opening bracket.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None,
                                     parseIntegerOrValue) ||
      parser.parseRSquare())
    return parser.emitError(parser.getNameLoc())
           << "expected a list of SSA values or integers";

  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
  return success();
}

static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                          OperandRange values,
                                          DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;

  printDynamicIndexList(printer, op, values, integers.asArrayRef(),
                        /*scalableFlags=*/{});
}
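// Illustrative example (not from the original source): the custom directive
// accepts mixed lists such as `[%x, 16, %y]`, which parse into
// values = {%x, %y} and integers = [kDynamic, 16, kDynamic]; an absent list
// leaves `integers` null, so nothing is printed back.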
void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
               l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint,
                         xegpu::DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
        l2_hint, l3_hint, layout);
}
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, UnitAttr packed,
                     DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, retType, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
               l3_hint, /*layout=*/nullptr);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                     UnitAttr packed, DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint,
                     xegpu::DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        packed, transpose, l1_hint, l2_hint, l3_hint, layout);
}
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is smaller than the descriptor, this is a SIMT
  // distribution; it is valid as long as it evenly divides the descriptor.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need a LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // The permutation must be within the bounds of the descriptor rank.
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  return build(builder, state, value, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
               /*layout=*/nullptr);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint,
                      xegpu::DistributeLayoutAttr layout) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        l1_hint, l2_hint, l3_hint, layout);
}
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType();
  auto valTy = getValueType();

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // A 1D value smaller than the descriptor is a SIMT distribution; it is
  // valid as long as it evenly divides the descriptor elements.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need a LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;

    return success();
  }

  // In SIMD mode, the value shape must match the descriptor shape.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets must match the rank of the tensor descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets())
    return emitOpError("Invalid number of offsets.");

  return success();
}
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the TensorDesc must match the source.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // The expected descriptor shape is the offsets shape, with the chunk size
  // appended as the innermost dimension when chunked access is used.
  auto chunkSize = tdescTy.getChunkSizeAsInt();
  SmallVector<int64_t> shape(getShapeOf(getOffsets().getType()));
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
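// Illustrative example (types assumed): 16 per-lane offsets with
// chunk_size = 8 must produce a descriptor of shape 16x8, i.e. the offsets
// shape [16] with the chunk size appended, exactly as computed above.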
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto srcTy = getSourceType();
  if (srcTy.isInteger() && !getOffsetAlignByteAttr())
    return emitOpError("offset_align_byte is required with integer source.");

  if (getOffsetAlignByteAttr() && !srcTy.isInteger())
    return emitOpError("offset_align_byte only allowed with integer source.");

  return success();
}
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
                       xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
        IntegerAttr{}, /*layout=*/nullptr);
}
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  // Raw-buffer form: validate against the source type directly.
  auto srcTy = getSourceType();
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(srcTy);

  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}
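// Illustrative IR (assumed syntax, raw-buffer form without a TensorDesc):
//   %v = xegpu.load %src[%offsets], %mask <{chunk_size = 8 : i64}>
//        : ui64, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>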
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint,
                         DistributeLayoutAttr layout) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, layout);
}
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  // Raw-buffer form: validate against the destination type directly.
  auto destTy = getDestType();
  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(destTy);

  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy,
                                          chunkSize,
                                          [&]() { return emitOpError(); });
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, /*layout=*/nullptr);
}

void StoreScatterOp::build(
    OpBuilder &builder, OperationState &state, Value value, Value dest,
    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint, layout);
}
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}
LogicalResult UpdateOffsetOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The offsets shape is the descriptor shape with the chunk dimension
  // dropped.
  auto expectedOffsetShape = getShapeOf(tdescTy);
  auto offsetShape = getShapeOf(getOffsets().getType());
  if (tdescTy.getChunkSizeAsInt() > 1)
    expectedOffsetShape.pop_back();

  if (expectedOffsetShape != offsetShape)
    return emitOpError(
        "Offsets should match TensorDesc except the chunk size dim.");

  return success();
}
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: all operands are 1D; the size of the B operand has to be a
  // multiple of 32 bits.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code: check shape consistency of the matrix multiply.
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}
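// Worked example (illustrative): lhs vector<8x16xf16>, rhs vector<8x16x2xf16>
// (3D packed form), result vector<8x16xf32>. Then
// bK = rhsShape[0] * rhsShape[2] = 8 * 2 = 16, which matches lhsShape[1], and
// the M/N checks compare 8 and 16 against the result shape.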
LogicalResult ConvertLayoutOp::verify() {
  auto srcLayout = getInputLayout();
  auto resLayout = getTargetLayout();

  // Input and target layouts must both be workgroup-level or both be
  // subgroup-level.
  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input layout and target layout be WgLayout or "
                       "SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
    return emitOpError(
        "invalid input layout, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
    return emitOpError(
        "invalid target layout, data cannot be evenly distributed.");

  return mlir::success();
}
OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
  // Converting to the same layout is a no-op; fold to the source.
  if (getInputLayout() == getTargetLayout())
    return getSource();
  return {};
}

struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInputLayout() == op.getTargetLayout()) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }
    return failure();
  }
};
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MemDescType> memDesc,
                         llvm::ArrayRef<OpFoldResult> offsets,
                         DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
  auto resTy = dyn_cast<VectorType>(getRes().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state,
                          Value data, TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
                          DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
        /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult StoreMatrixOp::verify() {
  auto dataTy = dyn_cast<VectorType>(getData().getType());
  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
  MemDescType mdescTy = getMemDesc().getType();

  return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
                               getLayoutAttr(), [&]() { return emitError(); });
}
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static Type getValueType(Attribute attr)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< int64_t > getShapeOf(Type type)
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
static std::string makeString(T array, bool breakline=false)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)
static bool isSharedMemory(const MemRefType &memrefTy)
static ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, function_ref< InFlightDiagnostic()> emitError)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::function_ref< Fn > function_ref
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.