19#include "llvm/Support/Debug.h"
21#define DEBUG_TYPE "xegpu"
27 Attribute attr = memrefTy.getMemorySpace();
28 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
29 return intAttr.getInt() == 3;
30 if (
auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
31 return memrefSpace.getValue() == MemorySpace::SLM;
32 if (
auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
38static std::string
makeString(T array,
bool breakline =
false) {
41 llvm::raw_string_ostream os(buf);
43 for (
size_t i = 1; i < array.size(); i++) {
44 os << array[i - 1] <<
", ";
48 os << array.back() <<
"]";
54 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
79 VectorType valueTy,
int64_t chunkSize,
82 auto maskVecTy = dyn_cast<VectorType>(maskTy);
83 auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
86 return emitError() <<
"Expecting chunk size == 1 for scalar result";
87 if (maskVecTy || offsetsVecTy)
88 return emitError() <<
"Expecting scalar mask and offsets.";
89 else if (maskVecTy && offsetsVecTy)
90 return emitError() <<
"Expecting a vector type result.";
94 auto valueSize = valueTy.getNumElements();
96 if (!maskVecTy && !offsetsVecTy) {
97 if (valueSize != chunkSize)
98 return emitError() <<
"value elements must match chunk size "
106 return emitError() <<
"Expecting a vector type mask.";
107 int64_t maskSize = maskVecTy.getNumElements();
110 if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
111 return emitError() <<
"value elements must match chunk size "
114 if (valueSize != maskSize)
116 <<
"Mask should match value except the chunk size dim.";
122 expectedMaskShape.pop_back();
123 if (expectedMaskShape != maskShape)
124 return emitError() <<
"Mask should match value except the chunk size dim.";
131 UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
135 if (subgroup_block_io)
136 return emitError() <<
"subgroup_block_io "
137 "are only allowed when result is a VectorType.";
142 if (mdescTy.getRank() < 2)
143 return emitError() <<
"mem_desc must be 2D or greater.";
149 ArrayAttr strideAttr = mdescTy.getStrideAttr();
151 for (
Attribute attr : strideAttr.getValue()) {
152 strides.push_back(cast<IntegerAttr>(attr).getInt());
154 if (subgroup_block_io && layout) {
155 auto laneData = layout.getEffectiveLaneDataAsInt();
156 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
157 if (!laneData.empty()) {
158 bool isLaneDataContiguous =
159 std::all_of(laneData.begin(), std::prev(laneData.end()),
160 [](
int x) { return x == 1; });
161 if (!isLaneDataContiguous)
162 return emitError() <<
"With subgroup_block_io, accessed data must be "
163 "contiguous and coalesced.";
164 for (
size_t i = 0; i < laneData.size(); ++i) {
165 if (laneLayout[i] != blockShape[i])
166 return emitError() <<
"With subgroup_block_io, the block shape must "
167 "match the lane layout.";
168 if (laneLayout[i] != 1 && strides[i] != 1)
169 return emitError() <<
"With subgroup_block_io, the distributed "
170 "dimensions must be contiguous.";
175 if (layout && !layout.isDistributable(
177 return emitError() <<
"Value shape is not distributable with the layout";
179 if (dataShape.size() == 2) {
180 if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
181 [](
auto p) {
return std::get<0>(p) > std::get<1>(p); }))
182 return emitError() <<
"data shape must not exceed mem_desc shape.";
186 if (subgroup_block_io && !blockShape.size())
187 return emitError() <<
"mem_desc must have block attribute when "
188 "subgroup_block_io is set.";
191 if (subgroup_block_io && mdescTy.isColMajor())
192 return emitError() <<
"mem_desc should be row major when "
193 "subgroup_block_io is set.";
205 [[maybe_unused]]
auto ty = source.getType();
206 assert(ty.hasStaticShape() &&
"expecting a memref with static shape");
208 build(builder, state, tdesc, source,
ValueRange({}) ,
221 assert((isa<IntegerType, MemRefType>(srcTy)) &&
222 "Source has to be either int or memref.");
236 if (
auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
237 auto memrefShape = memrefTy.getShape();
238 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
243 if (staticShape == memrefShape && staticStrides == memrefStrides &&
244 dynamicShape.empty() && dynamicStrides.empty()) {
250 build(builder, state, tdesc, source,
ValueRange({}), dynamicShape,
258 [[maybe_unused]]
auto ty = source.getType();
259 assert(ty.hasStaticShape() && offsets.size() == (
size_t)ty.getRank());
265 build(builder, state, tdesc, source, dynamicOffsets ,
277 assert(!
shape.empty() && !offsets.empty() && !strides.empty() &&
278 shape.size() == strides.size() &&
shape.size() == offsets.size());
281 assert((isa<IntegerType, MemRefType>(srcTy)) &&
282 "Source has to be either int or memref.");
300 if (
auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
301 auto memrefShape = memrefTy.getShape();
302 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
307 if (staticShape == memrefShape && staticStrides == memrefStrides &&
308 dynamicShape.empty() && dynamicStrides.empty()) {
314 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
315 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
318LogicalResult CreateNdDescOp::verify() {
320 bool invalidRank = rank != getMixedStrides().size();
321 bool invalidElemTy =
false;
327 auto srcMemorySpace = getSourceMemorySpace();
328 auto tdescMemorySpace =
static_cast<unsigned>(
getType().getMemorySpace());
329 if (srcMemorySpace != tdescMemorySpace)
331 <<
" Source: " << srcMemorySpace
332 <<
", TensorDesc: " << tdescMemorySpace;
334 if (
size_t offsetRank = getMixedOffsets().size())
335 invalidRank |= (offsetRank != rank);
339 if (
auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
342 if (llvm::isa<IntegerType>(getSourceType())) {
345 return emitOpError(
"expecting strides and shape to be present for "
351 "Expecting the rank of shape, strides, offsets, and source (if source "
352 "is a memref) should match with each other.");
357 "Expecting the TensorDesc rank is not greater than the "
358 "ranks of shape, strides, offsets or the memref source.");
361 return emitOpError(
"TensorDesc should have the same element "
362 "type with the source if it is a memref.\n");
374 auto parseIntegerOrValue = [&]() {
378 if (res.has_value() && succeeded(res.value())) {
379 values.push_back(operand);
380 integerVals.push_back(ShapedType::kDynamic);
381 if (valueTypes && parser.
parseColonType(valueTypes->emplace_back()))
387 integerVals.push_back(integer);
397 <<
"expected a list of SSA values or integers";
408 if (!integers || integers.empty())
418 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
419 xegpu::CachePolicyAttr l2_hint,
420 xegpu::CachePolicyAttr l3_hint) {
423 l1_hint, l2_hint, l3_hint,
nullptr);
428 xegpu::CachePolicyAttr l1_hint,
429 xegpu::CachePolicyAttr l2_hint,
430 xegpu::CachePolicyAttr l3_hint,
431 xegpu::DistributeLayoutAttr layout) {
438 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
439 l2_hint, l3_hint, layout);
442LogicalResult PrefetchNdOp::verify() {
443 auto tdescTy = getTensorDescType();
446 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
449 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
452 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
454 int64_t tDescRank = tdescTy.getRank();
455 int64_t offsetSize = getMixedOffsets().size();
456 if (offsetSize != 0 && offsetSize != tDescRank)
458 "Mismatched ranks between offsets and tensor descriptor");
460 if (
auto layout = getAnchorLayout()) {
461 if (!layout.isDistributable(
getShapeOf(tdescTy)))
463 "TensorDesc shape is not distributable with the layout");
474 Value tensorDesc, UnitAttr packed,
476 xegpu::CachePolicyAttr l1_hint,
477 xegpu::CachePolicyAttr l2_hint,
478 xegpu::CachePolicyAttr l3_hint) {
480 return build(builder, state, retType, tensorDesc,
ValueRange(),
488 xegpu::CachePolicyAttr l1_hint,
489 xegpu::CachePolicyAttr l2_hint,
490 xegpu::CachePolicyAttr l3_hint,
491 xegpu::DistributeLayoutAttr layout) {
498 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
499 packed, transpose, l1_hint, l2_hint, l3_hint,
503LogicalResult LoadNdOp::verify() {
504 auto tdescTy = getTensorDescType();
507 if (tdescTy.getRank() > 2)
508 return emitOpError(
"Expects a 1D or 2D TensorDesc.\n");
511 return emitOpError(
"Invalid result, it should be a VectorType.\n");
514 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
517 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
520 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
522 int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
523 int valueElems = valueTy.getNumElements();
528 if (valueElems < tdescElems && valueTy.getRank() == 1) {
530 if (tdescTy.getLayoutAttr())
532 <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
537 if (tdescElems % valueElems)
540 <<
" is not a valid distribution for tensor descriptor "
550 if (getTranspose()) {
551 auto trans = getTranspose().value();
553 if (llvm::all_of(trans, [&](
size_t s) {
return s < tdescShape.size(); }))
560 if (tdescTy.getRank() == 2) {
562 auto vnni_factor = valueShape.back();
563 tdescShape[axis] /= vnni_factor;
564 tdescShape.push_back(vnni_factor);
567 <<
"Invalid Packed Attr. It is ignored (available for 2D "
572 auto array_len = tdescTy.getArrayLength();
574 tdescShape.insert(tdescShape.begin(), array_len);
576 if (tdescShape != valueShape)
578 <<
" is not consistent with tensor descriptor "
581 int64_t tDescRank = tdescTy.getRank();
582 int64_t offsetSize = getMixedOffsets().size();
583 if (offsetSize != 0 && offsetSize != tDescRank)
585 "Mismatched ranks between offsets and tensor descriptor");
587 if (
auto layout = getAnchorLayout()) {
588 if (!layout.isDistributable(
getShapeOf(tdescTy)))
590 "TensorDesc shape is not distributable with the layout");
601 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
602 xegpu::CachePolicyAttr l2_hint,
603 xegpu::CachePolicyAttr l3_hint) {
605 return build(builder, state, value, tensorDesc,
ValueRange(),
612 xegpu::CachePolicyAttr l1_hint,
613 xegpu::CachePolicyAttr l2_hint,
614 xegpu::CachePolicyAttr l3_hint,
615 xegpu::DistributeLayoutAttr layout) {
622 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
623 l1_hint, l2_hint, l3_hint, layout);
626LogicalResult StoreNdOp::verify() {
627 auto dstTy = getTensorDescType();
630 if (dstTy.getRank() > 2)
631 return emitOpError(
"Expects a 1D or 2D TensorDesc.\n");
634 return emitOpError(
"Expecting a VectorType result.\n");
637 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
640 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
643 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
645 auto array_len = dstTy.getArrayLength();
647 return emitOpError(
"array length is not supported by store_nd.\n");
649 auto tdescElems = dstTy.getNumElements();
650 auto valueElems = valTy.getNumElements();
655 if (valTy.getRank() == 1 && valueElems < tdescElems) {
657 if (dstTy.getLayoutAttr())
659 <<
"TensorDesc doesn't need LayoutAttr for SIMT code";
661 if (tdescElems % valueElems)
664 <<
" is not a valid distribution for tensor descriptor " << dstTy;
672 if (tdescShape != valueShape)
674 <<
" is not consistent with tensor descriptor "
677 int64_t tDescRank = dstTy.getRank();
678 int64_t offsetSize = getMixedOffsets().size();
679 if (offsetSize != 0 && offsetSize != tDescRank)
681 "Mismatched ranks between offsets and tensor descriptor");
683 if (
auto layout = getAnchorLayout()) {
684 if (!layout.isDistributable(tdescShape))
686 "TensorDesc shape is not distributable with the layout");
695LogicalResult UpdateNdOffsetOp::verify() {
696 auto ty = getTensorDescType();
699 if (ty.getRank() != (
int64_t)getNumOffsets()) {
708LogicalResult PrefetchOp::verify() {
709 auto tdescTy = getTensorDescType();
711 if (!tdescTy && !getOffsets())
714 if (tdescTy && getOffsets())
718 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
721 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
724 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
726 auto srcTy = getSourceType();
727 if (srcTy.
isInteger() && !getOffsetAlignByteAttr())
728 return emitOpError(
"offset_align_byte is required with integer source.");
730 if (getOffsetAlignByteAttr() && !srcTy.
isInteger())
731 return emitOpError(
"offset_align_byte only allowed with integer source.");
733 if (
auto layout = getAnchorLayout()) {
735 if (
auto offsets = getOffsets()) {
736 auto offsetsTy = offsets.getType();
737 if (llvm::isa<VectorType>(offsetsTy) &&
738 !layout.isDistributable(
getShapeOf(offsetsTy)))
739 return emitOpError(
"offset shape is not distributable with the layout");
747 xegpu::CachePolicyAttr l1_hint,
748 xegpu::CachePolicyAttr l2_hint,
749 xegpu::CachePolicyAttr l3_hint) {
750 build(builder, state, source,
Value(), l1_hint, l2_hint, l3_hint,
751 IntegerAttr{},
nullptr);
757LogicalResult LoadGatherOp::verify() {
758 auto tdescTy = getTensorDescType();
759 auto maskTy = getMaskType();
762 if (!tdescTy && !getOffsets())
765 if (tdescTy && getOffsets())
769 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
772 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
775 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
777 auto srcTy = getSourceType();
778 uint64_t chunkSize =
static_cast<int64_t>(getChunkSize().value_or(1));
779 auto memTy = dyn_cast<MemRefType>(srcTy);
782 return emitError() <<
"Value should have the same element type as MemRef.";
784 if (
auto layout = getAnchorLayout()) {
785 if (!layout.isDistributable(
getShapeOf(valueTy)))
786 return emitOpError(
"Value shape is not distributable with the layout");
789 auto offsetsTy = getOffsets().getType();
796 xegpu::CachePolicyAttr l1_hint,
797 xegpu::CachePolicyAttr l2_hint,
798 xegpu::CachePolicyAttr l3_hint) {
799 build(builder, state, valueType, source,
Value(), mask, IntegerAttr(),
800 l1_hint, l2_hint, l3_hint,
nullptr);
806 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
807 xegpu::CachePolicyAttr l2_hint,
808 xegpu::CachePolicyAttr l3_hint) {
809 auto loc = source.
getLoc();
811 auto type = VectorType::get(size, builder.
getIndexType());
813 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
815 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
816 l2_hint, l3_hint,
nullptr);
822 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
823 xegpu::CachePolicyAttr l2_hint,
824 xegpu::CachePolicyAttr l3_hint,
825 DistributeLayoutAttr layout) {
826 auto loc = source.
getLoc();
828 auto type = VectorType::get(size, builder.
getIndexType());
830 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
832 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
833 l2_hint, l3_hint, layout);
839LogicalResult StoreScatterOp::verify() {
840 auto tdescTy = getTensorDescType();
841 auto maskTy = getMaskType();
844 if (!tdescTy && !getOffsets())
847 if (tdescTy && getOffsets())
851 return emitOpError(
"invalid l1_hint: ") << getL1HintAttr();
854 return emitOpError(
"invalid l2_hint: ") << getL2HintAttr();
857 return emitOpError(
"invalid l3_hint: ") << getL3HintAttr();
859 auto destTy = getDestType();
860 uint64_t chunkSize =
static_cast<int64_t>(getChunkSize().value_or(1));
861 auto memTy = dyn_cast<MemRefType>(destTy);
864 return emitError() <<
"Value should have the same element type as MemRef.";
866 if (
auto layout = getAnchorLayout()) {
867 if (!layout.isDistributable(
getShapeOf(valueTy)))
868 return emitOpError(
"Value shape is not distributable with the layout");
871 auto offsetsTy = getOffsets().getType();
878 xegpu::CachePolicyAttr l1_hint,
879 xegpu::CachePolicyAttr l2_hint,
880 xegpu::CachePolicyAttr l3_hint) {
881 build(builder, state, value, dest,
Value(), mask, IntegerAttr(), l1_hint,
882 l2_hint, l3_hint,
nullptr);
888 IntegerAttr chunk_size,
889 xegpu::CachePolicyAttr l1_hint,
890 xegpu::CachePolicyAttr l2_hint,
891 xegpu::CachePolicyAttr l3_hint) {
894 auto type = VectorType::get(size, builder.
getIndexType());
896 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
899 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
903void StoreScatterOp::build(
906 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
907 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
910 auto type = VectorType::get(size, builder.
getIndexType());
912 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
915 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
922LogicalResult DpasOp::verify() {
923 int64_t lhsRank = getLhsType().getRank();
924 int64_t rhsRank = getRhsType().getRank();
925 int64_t resRank = getResultType().getRank();
926 auto lhsShape = getLhsType().getShape();
927 auto rhsShape = getRhsType().getShape();
928 auto resShape = getResultType().getShape();
930 if (
auto cdLayout = getLayoutCd())
931 if (!cdLayout->isDistributable(
933 return emitOpError(
"Value shape is not distributable with the layout");
935 if (
auto aLayout = getLayoutA())
936 if (!aLayout->isDistributable(
938 return emitOpError(
"Value shape is not distributable with the layout");
940 if (
auto bLayout = getLayoutB())
941 if (!bLayout->isDistributable(
943 return emitOpError(
"Value shape is not distributable with the layout");
945 if (getAcc() && getAcc().
getType() != getResultType())
946 return emitOpError(
"Expecting the acc type to be the same as result.");
951 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
952 auto numElems = getRhsType().getNumElements();
953 auto elemTy = getRhsType().getElementType();
954 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
955 if (numElems % factor != 0)
956 return emitOpError(
"Expecting B operand to be a multiple of 32 bits.");
961 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
963 "expecting lhs and result to be a 2D vector, and rhs to be either "
964 "2D or 3D (packed) vector.");
965 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
966 if (bK != lhsShape[1])
968 if (lhsShape[0] != resShape[0])
970 if (rhsShape[1] != resShape[1])
979LogicalResult ConvertLayoutOp::verify() {
980 auto srcLayout = getInputLayout();
981 auto resLayout = getTargetLayout();
989 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
990 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
991 return emitOpError(
"expected input layout and target layout be WgLayout or "
992 "SgLayout at the same time.");
994 Type srcType = getSource().getType();
995 if (llvm::isa<VectorType>(srcType)) {
997 if (!srcLayout.isDistributable(
shape))
999 "invalid input layout, data cannot be evenly distributed.");
1001 if (!resLayout.isDistributable(
shape))
1003 "invalid target layout, data cannot be evenly distributed.");
1005 return mlir::success();
1014 DistributeLayoutAttr layout) {
1021 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1025LogicalResult LoadMatrixOp::verify() {
1027 auto resTy = dyn_cast<VectorType>(getRes().
getType());
1028 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1029 MemDescType mdescTy = getMemDesc().getType();
1032 getLayoutAttr(), [&]() {
return emitError(); });
1041 DistributeLayoutAttr layout) {
1046 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1050LogicalResult StoreMatrixOp::verify() {
1052 auto dataTy = dyn_cast<VectorType>(getData().
getType());
1053 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1054 MemDescType mdescTy = getMemDesc().getType();
1056 getLayoutAttr(), [&]() {
return emitError(); });
1063LogicalResult TruncfOp::verify() {
1064 auto sourceVecType = dyn_cast<VectorType>(getSource().
getType());
1065 auto resultVecType = dyn_cast<VectorType>(getResult().
getType());
1067 if (sourceVecType.getElementTypeBitWidth() <=
1068 resultVecType.getElementTypeBitWidth())
1069 return emitOpError(
"input type must be wider than result type.");
1078LogicalResult DpasMxOp::verify() {
1079 if (getAcc() && getAcc().
getType() != getResultType())
1080 return emitOpError(
"Expecting the acc type to be the same as result.");
1086#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1088#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1089#define GET_OP_CLASSES
1090#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static Type getValueType(Attribute attr)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< int64_t > getShapeOf(Type type)
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
static std::string makeString(T array, bool breakline=false)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)
static bool isSharedMemory(const MemRefType &memrefTy)
static ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
This class represents a diagnostic that is inflight and set to be reported.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::function_ref< Fn > function_ref
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.