19 #include "llvm/Support/Debug.h"
21 #define DEBUG_TYPE "xegpu"
27 Attribute attr = memrefTy.getMemorySpace();
28 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
29 return intAttr.getInt() == 3;
30 if (
auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
31 return memrefSpace.getValue() == MemorySpace::SLM;
32 if (
auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
38 static std::string
makeString(T array,
bool breakline =
false) {
41 llvm::raw_string_ostream os(buf);
43 for (
size_t i = 1; i < array.size(); i++) {
44 os << array[i - 1] <<
", ";
48 os << array.back() <<
"]";
54 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED ||
kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING ||
kind == CachePolicy::READ_INVALIDATE;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED ||
kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK ||
kind == CachePolicy::WRITE_THROUGH;
79 TensorDescType tdescTy,
82 if (!tdescTy.isScattered())
83 return emitError() <<
"Expects a scattered TensorDesc.";
85 auto chunkSize = tdescTy.getChunkSizeAsInt();
88 return emitError() <<
"Expecting chunk size == 1 for scalar result";
89 if (dyn_cast<VectorType>(maskTy))
90 return emitError() <<
"Expecting a vector type result.";
98 if (valueTy.getElementType() != tdescTy.getElementType())
100 <<
"Value should have the same element type as TensorDesc.";
104 expectedMaskShape.pop_back();
105 if (expectedMaskShape != maskShape)
107 <<
"Mask should match TensorDesc except the chunk size dim.";
110 if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
111 if (tdescTy.getLayoutAttr())
112 return emitError() <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
116 if (tdescShape != valueShape)
118 <<
" is neither a valid distribution for SIMT nor "
119 "consistent with the tensor descriptor for SIMD "
126 VectorType valueTy, int64_t chunkSize,
129 auto maskVecTy = dyn_cast<VectorType>(maskTy);
130 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
133 return emitError() <<
"Expecting chunk size == 1 for scalar result";
134 if (maskVecTy || offsetsVecTy)
135 return emitError() <<
"Expecting scalar mask and offsets.";
136 else if (maskVecTy && offsetsVecTy)
137 return emitError() <<
"Expecting a vector type result.";
141 auto valueSize = valueTy.getNumElements();
143 if (!maskVecTy && !offsetsVecTy) {
144 if (valueSize != chunkSize)
145 return emitError() <<
"value elements must match chunk size "
153 return emitError() <<
"Expecting a vector type mask.";
154 int64_t maskSize = maskVecTy.getNumElements();
157 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
158 return emitError() <<
"value elements must match chunk size "
161 if (valueSize != maskSize)
163 <<
"Mask should match value except the chunk size dim.";
169 expectedMaskShape.pop_back();
170 if (expectedMaskShape != maskShape)
171 return emitError() <<
"Mask should match value except the chunk size dim.";
182 [[maybe_unused]]
auto ty = source.getType();
183 assert(ty.hasStaticShape() &&
"expecting a memref with static shape");
185 build(builder, state, tdesc, source,
ValueRange({}) ,
193 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
194 Type tdesc, Value source,
197 Type srcTy = source.getType();
198 assert((isa<IntegerType, MemRefType>(srcTy)) &&
199 "Source has to be either int or memref.");
210 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
211 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
213 if (
auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
214 auto memrefShape = memrefTy.getShape();
215 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
219 if (staticShape == memrefShape && staticStrides == memrefStrides) {
225 build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
226 dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
230 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
231 Type tdesc, TypedValue<MemRefType> source,
233 [[maybe_unused]]
auto ty = source.getType();
234 assert(ty.hasStaticShape() && offsets.size() == (
size_t)ty.getRank());
240 build(builder, state, tdesc, source, dynamicOffsets ,
243 builder.getDenseI64ArrayAttr(staticOffsets) ,
247 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
248 Type tdesc, Value source,
252 assert(shape.size() && offsets.size() && strides.size() &&
253 shape.size() == strides.size() && shape.size() == offsets.size());
255 Type srcTy = source.getType();
256 assert((isa<IntegerType, MemRefType>(srcTy)) &&
257 "Source has to be either int or memref.");
271 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
272 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
273 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
275 if (
auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
276 auto memrefShape = memrefTy.getShape();
277 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
281 if (staticShape == memrefShape && staticStrides == memrefStrides) {
287 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
288 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
293 bool invalidRank = rank != getMixedStrides().size();
294 bool invalidElemTy =
false;
300 auto srcMemorySpace = getSourceMemorySpace();
301 auto tdescMemorySpace =
static_cast<unsigned>(
getType().getMemorySpace());
302 if (srcMemorySpace != tdescMemorySpace)
303 return emitOpError(
"Memory space mismatch.")
304 <<
" Source: " << srcMemorySpace
305 <<
", TensorDesc: " << tdescMemorySpace;
307 if (
size_t offsetRank = getMixedOffsets().size())
308 invalidRank |= (offsetRank != rank);
312 if (
auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
315 if (llvm::isa<IntegerType>(getSourceType())) {
318 return emitOpError(
"expecting strides and shape to be present for "
324 "Expecting the rank of shape, strides, offsets, and source (if source "
325 "is a memref) should match with each other.");
328 if (
getType().getRank() > (int64_t)rank)
330 "Expecting the TensorDesc rank is not greater than the "
331 "ranks of shape, strides, offsets or the memref source.");
334 return emitOpError(
"TensorDesc should have the same element "
335 "type with the source if it is a memref.\n");
338 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
350 auto parseIntegerOrValue = [&]() {
354 if (res.has_value() && succeeded(res.value())) {
355 values.push_back(operand);
356 integerVals.push_back(ShapedType::kDynamic);
357 if (valueTypes && parser.
parseColonType(valueTypes->emplace_back()))
363 integerVals.push_back(integer);
373 <<
"expected a list of SSA values or integers";
384 if (!integers || integers.empty())
394 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
395 xegpu::CachePolicyAttr l2_hint,
396 xegpu::CachePolicyAttr l3_hint) {
399 l1_hint, l2_hint, l3_hint);
402 void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
403 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
404 xegpu::CachePolicyAttr l1_hint,
405 xegpu::CachePolicyAttr l2_hint,
406 xegpu::CachePolicyAttr l3_hint) {
407 SmallVector<Value> dynamicOffsets;
408 SmallVector<int64_t> staticOffsets;
411 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
413 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
418 auto tdescTy = getTensorDescType();
419 if (tdescTy.isScattered())
420 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
423 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
426 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
429 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
431 int64_t tDescRank = tdescTy.getRank();
432 int64_t offsetSize =
static_cast<int64_t
>(getOffsets().size());
433 int64_t constOffsetSize =
434 getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
435 if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
436 ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
438 "Mismatched ranks between offsets and tensor descriptor");
447 void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
448 Value tensorDesc, UnitAttr packed,
450 xegpu::CachePolicyAttr l1_hint,
451 xegpu::CachePolicyAttr l2_hint,
452 xegpu::CachePolicyAttr l3_hint) {
454 return build(builder, state, retType, tensorDesc, ValueRange(),
459 void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
460 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
462 xegpu::CachePolicyAttr l1_hint,
463 xegpu::CachePolicyAttr l2_hint,
464 xegpu::CachePolicyAttr l3_hint) {
465 SmallVector<Value> dynamicOffsets;
466 SmallVector<int64_t> staticOffsets;
469 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
471 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
472 packed, transpose, l1_hint, l2_hint, l3_hint);
476 auto tdescTy = getTensorDescType();
479 if (tdescTy.isScattered())
480 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
482 if (tdescTy.getRank() > 2)
483 return emitOpError(
"Expects a 1D or 2D TensorDesc.\n");
486 return emitOpError(
"Invalid result, it should be a VectorType.\n");
489 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
492 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
495 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
497 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
498 int valueElems = valueTy.getNumElements();
503 if (valueElems < tdescElems && valueTy.getRank() == 1) {
505 if (tdescTy.getLayoutAttr())
507 <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
512 if (tdescElems % valueElems)
515 <<
" is not a valid distribution for tensor descriptor "
525 if (getTranspose()) {
526 auto trans = getTranspose().value();
528 if (llvm::all_of(trans, [&](
size_t s) {
return s < tdescShape.size(); }))
535 if (tdescTy.getRank() == 2) {
537 auto vnni_factor = valueShape.back();
538 tdescShape[axis] /= vnni_factor;
539 tdescShape.push_back(vnni_factor);
542 <<
"Invalid Packed Attr. It is ignored (available for 2D "
547 auto array_len = tdescTy.getArrayLength();
549 tdescShape.insert(tdescShape.begin(), array_len);
551 if (tdescShape != valueShape)
552 return emitOpError() <<
"Result shape " <<
makeString(valueShape)
553 <<
" is not consistent with tensor descriptor "
556 int64_t tDescRank = tdescTy.getRank();
557 int64_t offsetSize =
static_cast<int64_t
>(getOffsets().size());
558 int64_t constOffsetSize =
559 getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
560 if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
561 ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
563 "Mismatched ranks between offsets and tensor descriptor");
572 void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
573 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
574 xegpu::CachePolicyAttr l2_hint,
575 xegpu::CachePolicyAttr l3_hint) {
577 return build(builder, state, value, tensorDesc, ValueRange(),
581 void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
582 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
583 xegpu::CachePolicyAttr l1_hint,
584 xegpu::CachePolicyAttr l2_hint,
585 xegpu::CachePolicyAttr l3_hint) {
586 SmallVector<Value> dynamicOffsets;
587 SmallVector<int64_t> staticOffsets;
590 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
592 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
593 l1_hint, l2_hint, l3_hint);
597 auto dstTy = getTensorDescType();
600 if (dstTy.isScattered())
601 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
603 if (dstTy.getRank() > 2)
604 return emitOpError(
"Expects a 1D or 2D TensorDesc.\n");
607 return emitOpError(
"Expecting a VectorType result.\n");
610 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
613 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
616 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
618 auto array_len = dstTy.getArrayLength();
620 return emitOpError(
"array length is not supported by store_nd.\n");
622 auto tdescElems = dstTy.getNumElements();
623 auto valueElems = valTy.getNumElements();
628 if (valTy.getRank() == 1 && valueElems < tdescElems) {
630 if (dstTy.getLayoutAttr())
632 <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
634 if (tdescElems % valueElems)
637 <<
" is not a valid distribution for tensor descriptor " << dstTy;
645 if (tdescShape != valueShape)
646 return emitOpError() <<
"Value shape " <<
makeString(valueShape)
647 <<
" is not consistent with tensor descriptor "
650 int64_t tDescRank = dstTy.getRank();
651 int64_t offsetSize =
static_cast<int64_t
>(getOffsets().size());
652 int64_t constOffsetSize =
653 getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
654 if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
655 ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
657 "Mismatched ranks between offsets and tensor descriptor");
666 auto ty = getTensorDescType();
667 if (ty.isScattered())
668 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
671 if (ty.getRank() != (int64_t)getNumOffsets()) {
672 return emitOpError(
"Invalid number of offsets.");
681 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
682 TensorDescType TensorDesc, Value source,
684 auto loc = source.getLoc();
685 int64_t size =
static_cast<int64_t
>(offsets.size());
688 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
689 build(builder, state, TensorDesc, source, offset);
692 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
693 TensorDescType TensorDesc, Value source,
696 build(builder, state, TensorDesc, source, ofrs);
700 auto tdescTy = getTensorDescType();
702 if (!tdescTy.isScattered())
703 return emitOpError(
"Expects a scattered TensorDesc.\n");
709 auto srcMemorySpace = getSourceMemorySpace();
710 auto tdescMemorySpace =
static_cast<unsigned>(tdescTy.getMemorySpace());
711 if (srcMemorySpace != tdescMemorySpace)
712 return emitOpError(
"Memory space mismatch.")
713 <<
" Source: " << srcMemorySpace
714 <<
", TensorDesc: " << tdescMemorySpace;
717 auto chunkSize = tdescTy.getChunkSizeAsInt();
718 SmallVector<int64_t> shape(getOffsetsType().
getShape());
720 shape.push_back(chunkSize);
723 if (shape != tdescShape)
724 return emitOpError(
"Incorrect TensorDesc shape. ")
725 <<
"Expected is " <<
makeString(shape) <<
"\n";
734 auto tdescTy = getTensorDescType();
736 if (!tdescTy && !getOffsets())
737 return emitOpError(
"Expects offsets.");
739 if (tdescTy && getOffsets())
740 return emitOpError(
"offsets not allowed.");
742 if (tdescTy && !tdescTy.isScattered())
743 return emitOpError(
"Expects a scattered TensorDesc.");
746 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
749 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
752 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
754 auto srcTy = getSourceType();
755 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
756 return emitOpError(
"offset_align_byte is required with integer source.");
758 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
759 return emitOpError(
"offset_align_byte only allowed with integer source.");
764 void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
765 xegpu::CachePolicyAttr l1_hint,
766 xegpu::CachePolicyAttr l2_hint,
767 xegpu::CachePolicyAttr l3_hint) {
768 build(builder, state, source,
Value(), l1_hint, l2_hint, l3_hint,
776 auto tdescTy = getTensorDescType();
777 auto maskTy = getMaskType();
780 if (!tdescTy && !getOffsets())
781 return emitOpError(
"Expects offsets.");
783 if (tdescTy && getOffsets())
784 return emitOpError(
"offsets not allowed.");
786 if (tdescTy && !tdescTy.isScattered())
787 return emitOpError(
"Expects a scattered TensorDesc.");
790 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
793 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
796 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
800 [&]() {
return emitOpError(); });
801 auto srcTy = getSourceType();
802 uint64_t chunkSize =
static_cast<int64_t
>(getChunkSize().value_or(1));
803 auto memTy = dyn_cast<MemRefType>(srcTy);
806 return emitError() <<
"Value should have the same element type as MemRef.";
808 auto offsetsTy = getOffsets().getType();
810 [&]() {
return emitOpError(); });
813 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
814 Type valueType, Value source, Value mask,
815 xegpu::CachePolicyAttr l1_hint,
816 xegpu::CachePolicyAttr l2_hint,
817 xegpu::CachePolicyAttr l3_hint) {
818 build(builder, state, valueType, source,
Value(), mask, IntegerAttr(),
819 l1_hint, l2_hint, l3_hint);
822 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
823 Type valueType, Value source,
824 ArrayRef<OpFoldResult> offsets, Value mask,
825 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
826 xegpu::CachePolicyAttr l2_hint,
827 xegpu::CachePolicyAttr l3_hint) {
828 auto loc = source.getLoc();
829 int64_t size =
static_cast<int64_t
>(offsets.size());
832 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
834 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
842 auto tdescTy = getTensorDescType();
843 auto maskTy = getMaskType();
846 if (!tdescTy && !getOffsets())
847 return emitOpError(
"Expects offsets.");
849 if (tdescTy && getOffsets())
850 return emitOpError(
"offsets not allowed.");
852 if (tdescTy && !tdescTy.isScattered())
853 return emitOpError(
"Expects a scattered TensorDesc.");
856 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
859 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
862 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
866 [&]() {
return emitOpError(); });
868 auto destTy = getDestType();
869 uint64_t chunkSize =
static_cast<int64_t
>(getChunkSize().value_or(1));
870 auto memTy = dyn_cast<MemRefType>(destTy);
873 return emitError() <<
"Value should have the same element type as MemRef.";
875 auto offsetsTy = getOffsets().getType();
877 [&]() {
return emitOpError(); });
880 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
881 Value value, Value dest, Value mask,
882 xegpu::CachePolicyAttr l1_hint,
883 xegpu::CachePolicyAttr l2_hint,
884 xegpu::CachePolicyAttr l3_hint) {
885 build(builder, state, value, dest,
Value(), mask, IntegerAttr(), l1_hint,
889 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
890 Value value, Value dest,
891 ArrayRef<OpFoldResult> offsets, Value mask,
892 IntegerAttr chunk_size,
893 xegpu::CachePolicyAttr l1_hint,
894 xegpu::CachePolicyAttr l2_hint,
895 xegpu::CachePolicyAttr l3_hint) {
896 auto loc = dest.getLoc();
897 int64_t size =
static_cast<int64_t
>(offsets.size());
900 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
903 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
910 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
913 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.
getType());
914 assert(tdescTy &&
"Expecting the source is a TensorDescType value.");
915 auto loc = tensorDesc.
getLoc();
916 int64_t size =
static_cast<int64_t
>(offsets.size());
919 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
920 build(builder, state, tdescTy, tensorDesc, offset);
923 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
926 build(builder, state, tensorDesc, ofrs);
930 auto tdescTy = getTensorDescType();
931 if (!tdescTy.isScattered())
932 return emitOpError(
"Expects a scattered TensorDesc.\n");
934 SmallVector<int64_t> expectedOffsetShape =
getShapeOf(tdescTy);
935 SmallVector<int64_t> offsetShape =
getShapeOf(getOffsetsType());
936 if (tdescTy.getChunkSizeAsInt() > 1)
937 expectedOffsetShape.pop_back();
939 if (expectedOffsetShape != offsetShape)
941 "Offsets should match TensorDesc except the chunk size dim.");
950 int64_t lhsRank = getLhsType().getRank();
951 int64_t rhsRank = getRhsType().getRank();
952 int64_t resRank = getResultType().getRank();
953 auto lhsShape = getLhsType().getShape();
954 auto rhsShape = getRhsType().getShape();
955 auto resShape = getResultType().getShape();
957 if (getAcc() && getAcc().
getType() != getResultType())
958 return emitOpError(
"Expecting the acc type to be the same as result.");
963 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
964 auto numElems = getRhsType().getNumElements();
965 auto elemTy = getRhsType().getElementType();
966 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
967 if (numElems % factor != 0)
968 return emitOpError(
"Expecting B operand to be a multiple of 32 bits.");
973 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
975 "expecting lhs and result to be a 2D vector, and rhs to be either "
976 "2D or 3D (packed) vector.");
977 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
978 if (bK != lhsShape[1])
979 return emitOpError(
"K-dimension mismatch.");
980 if (lhsShape[0] != resShape[0])
981 return emitOpError(
"M-dimension mismatch.");
982 if (rhsShape[1] != resShape[1])
983 return emitOpError(
"N-dimension mismatch.");
992 auto srcLayout = getInputLayout();
993 auto resLayout = getTargetLayout();
995 return emitOpError(
"expected input layout.");
997 return emitOpError(
"expected target layout.");
1001 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
1002 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
1003 return emitOpError(
"expected input layout and target layout be WgLayout or "
1004 "SgLayout at the same time.");
1006 auto shape = getSource().getType().getShape();
1007 if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
1009 "invalid input layout, data cannot be evenly distributed.");
1011 if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
1013 "invalid target layout, data cannot be evenly distributed.");
1015 return mlir::success();
1018 OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
1019 if (getInputLayout() == getTargetLayout())
1028 if (op.getInputLayout() == op.getTargetLayout()) {
1038 patterns.add<FoldConvertLayoutOp>(context);
1044 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
1045 TypedValue<MemDescType> memDesc,
1047 DistributeLayoutAttr layout) {
1051 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1052 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1057 VectorType resTy = getRes().getType();
1058 MemDescType mdescTy = getMemDesc().getType();
1060 if (mdescTy.getRank() != 2)
1061 return emitOpError(
"mem_desc must be 2D.");
1063 ArrayRef<int64_t> valueShape = resTy.getShape();
1064 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
1065 if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
1066 [](
auto p) {
return std::get<0>(p) > std::get<1>(p); }))
1067 return emitOpError(
"result shape must not exceed mem_desc shape.");
1074 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
1075 TypedValue<MemDescType> memDesc,
1077 DistributeLayoutAttr layout) {
1081 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1082 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1087 VectorType dataTy = getData().getType();
1088 MemDescType mdescTy = getMemDesc().getType();
1090 if (mdescTy.getRank() != 2)
1091 return emitOpError(
"mem_desc must be 2D.");
1093 ArrayRef<int64_t> dataShape = dataTy.getShape();
1094 ArrayRef<int64_t> mdescShape = mdescTy.getShape();
1095 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
1096 [](
auto p) {
return std::get<0>(p) > std::get<1>(p); }))
1097 return emitOpError(
"data shape must not exceed mem_desc shape.");
1106 void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
1107 Type resTy, Value src,
1112 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1113 build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
1117 MemDescType srcTy = getSrc().getType();
1118 MemDescType resTy = getRes().getType();
1119 ArrayRef<int64_t> srcShape = srcTy.getShape();
1120 ArrayRef<int64_t> resShape = resTy.getShape();
1122 if (srcTy.getRank() < resTy.getRank())
1123 return emitOpError(
"result rank must not exceed source rank.");
1126 llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
1127 [](
auto p) { return std::get<0>(p) > std::get<1>(p); }))
1128 return emitOpError(
"result shape must not exceed source shape.");
1130 if (srcTy.getStrides() != resTy.getStrides())
1131 return emitOpError(
"result must inherit the source strides.");
1140 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1142 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1143 #define GET_OP_CLASSES
1144 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
static Type getElementType(Type type)
Determine the element type of type.
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
static Type getValueType(Attribute attr)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
@ Type
An inlay hint that for a type annotation.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
static std::string makeString(T array, bool breakline=false)
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)
bool isSharedMemory(const MemRefType &memrefTy)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, function_ref< InFlightDiagnostic()> emitError)
static SmallVector< int64_t > getShapeOf(Type type)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override