30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallVector.h"
43template <
typename OpTy>
45 MemRefType memrefType,
47 int64_t rank = memrefType.getRank();
48 if (rank != numIndices)
49 return op.emitOpError(
"expected ")
50 << rank <<
" " << indexName <<
" indices, got " << numIndices;
57LogicalResult PackedTrunc2xFp8Op::verify() {
58 if (getExisting() && getExisting().
getType() != getResult().
getType())
59 return emitOpError(
"existing values must have same type as result");
63LogicalResult PackedStochRoundFp8Op::verify() {
64 if (getExisting() && getExisting().
getType() != getResult().
getType())
65 return emitOpError(
"existing values must have same type as result");
72LogicalResult PackedScaledTruncOp::verify() {
73 if (getExisting() && getExisting().
getType() != getResult().
getType())
74 return emitOpError(
"existing values must have same type as result");
91 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
92 MemRefLayoutAttrInterface layout = source.getLayout();
93 if (resetOffset && !layout.isIdentity()) {
94 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
97 MemRefLayoutAttrInterface newLayout =
98 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
103 if (source.hasStaticShape()) {
105 }
else if (source.getRank() <= 1) {
108 if (stridesIfIdentity == stridedLayout.getStrides()) {
109 newLayout = AffineMapAttr::get(
114 return (MemRefType)(mb);
117LogicalResult FatRawBufferCastOp::inferReturnTypes(
121 Adaptor adaptor(operands, attributes, properties, regions);
123 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
126 FailureOr<MemRefType> resultType =
134FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(
OpBuilder &builder,
137 assert(resultIndex == 0 &&
"FatRawBufferCastOp has a single result");
141LogicalResult FatRawBufferCastOp::verify() {
142 FailureOr<MemRefType> expectedResultType =
144 if (
failed(expectedResultType))
146 << getSource().getType() <<
" can't have its offset reset";
147 if (getResult().
getType() != *expectedResultType)
149 << *expectedResultType <<
" but got " << getResult().getType();
156 if (
auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
157 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
158 if (
auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
159 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
166 if (
auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
167 return intMemorySpace.getInt() == 3;
168 if (
auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
169 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
176 if (
auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
177 return intMemorySpace.getInt() == 7;
178 if (
auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
179 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
188 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
192 return op.emitOpError(
193 "buffer ops must operate on a memref in global memory");
194 if (!bufferType.hasRank())
195 return op.emitOpError(
196 "cannot meaningfully buffer_store to an unranked memref");
204LogicalResult RawBufferAtomicFaddOp::verify() {
208LogicalResult RawBufferAtomicFmaxOp::verify() {
212LogicalResult RawBufferAtomicSmaxOp::verify() {
216LogicalResult RawBufferAtomicUminOp::verify() {
220LogicalResult RawBufferAtomicCmpswapOp::verify() {
229 return cst.getZExtValue();
233template <
typename OpType>
235 if (!op.getBoundsCheck())
237 MemRefType bufferType = op.getMemref().getType();
238 if (!bufferType.hasStaticShape())
242 if (failed(bufferType.getStridesAndOffset(strides, offset)))
245 if (op.getSgprOffset()) {
251 if (strides.size() != op.getIndices().size())
254 for (
auto pair : llvm::zip(strides, op.getIndices())) {
255 int64_t stride = std::get<0>(pair);
256 Value idx = std::get<1>(pair);
260 indexVal += stride * *idxVal;
263 if (
result > std::numeric_limits<uint32_t>::max())
266 return result >= bufferType.getNumElements();
270template <
typename OpType>
271struct RemoveStaticallyOobBufferLoads final :
public OpRewritePattern<OpType> {
272 using OpRewritePattern<OpType>::OpRewritePattern;
274 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw)
const override {
277 Type loadType = op.getResult().getType();
284template <
typename OpType>
285struct RemoveStaticallyOobBufferWrites final :
public OpRewritePattern<OpType> {
286 using OpRewritePattern<OpType>::OpRewritePattern;
288 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw)
const override {
300 results.
add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
305 results.
add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
308void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
310 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
313void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
315 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
318void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
320 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
323void RawBufferAtomicUminOp::getCanonicalizationPatterns(
325 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
328void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
330 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
337LogicalResult ScaledExtPackedMatrixOp::verify() {
339 assert(llvm::is_contained({16, 32}, blockSize) &&
"invalid block size");
341 int firstScaleByte = getFirstScaleByte();
342 int firstScaleLane = getFirstScaleLane();
343 auto sourceType = cast<VectorType>(getSource().
getType());
344 Type elementType = sourceType.getElementType();
345 auto floatType = cast<FloatType>(elementType);
346 unsigned bitWidth = floatType.getWidth();
350 const bool is_fp8 = bitWidth == 8;
351 const bool is_block_16 = blockSize == 16;
355 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
356 return emitOpError(
"blockSize of 16 can only have firstScaleByte be 0 "
357 "or 1 for f4 and f6.");
360 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
361 return emitOpError(
"blockSize of 32 can only have firstScaleByte be 0 "
362 "or 2 for f4 and f6.");
367 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
368 ((firstScaleLane == 16) && (firstScaleByte == 2));
370 return emitOpError(
"blockSize of 16 can only have (firstScaleLane, "
371 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
384 IntegerAttr &m, IntegerAttr &n,
389 if (dimensions.size() != 3)
391 <<
"expected 3 dimensions in MNK dimension list";
399LogicalResult WMMAOp::verify() {
400 auto sourceAType = cast<VectorType>(getSourceA().
getType());
401 auto sourceBType = cast<VectorType>(getSourceB().
getType());
402 auto destType = cast<VectorType>(getDestC().
getType());
404 Type sourceAElemType = sourceAType.getElementType();
405 Type sourceBElemType = sourceBType.getElementType();
406 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
407 return emitOpError(
"source vectors have different lengths: ")
408 << sourceAType <<
" vs. " << sourceBType;
411 bool isDestFloat = destType.getElementType().
isFloat();
412 bool isSrcFloat = sourceAElemType.
isFloat();
414 if (isDestFloat && !isSrcFloat)
415 return emitOpError(
"expected float sources with float destination");
416 if (!isDestFloat && isSrcFloat)
417 return emitOpError(
"expected int sources with int destination");
419 if (!sourceAElemType.
isFloat(8) && sourceAElemType != sourceBElemType) {
421 "source element types must match (except for fp8/bf8) but have ")
422 << sourceAType <<
" and " << sourceBType;
427 return emitOpError(
"clamp flag is not supported for float types");
428 if (getUnsignedA() || getUnsignedB())
429 return emitOpError(
"unsigned flags are not supported for float types");
438LogicalResult ScaledWMMAOp::verify() {
440 auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
441 auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
442 auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
443 auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
444 auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
445 auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
447 auto sourceAType = cast<VectorType>(getSourceA().
getType());
448 auto sourceBType = cast<VectorType>(getSourceB().
getType());
449 auto destType = cast<VectorType>(getDestC().
getType());
452 Type aElemType = sourceAType.getElementType();
453 Type bElemType = sourceBType.getElementType();
457 int64_t aLen = sourceAType.getNumElements();
458 int64_t bLen = sourceBType.getNumElements();
459 int64_t expectedOutLen = (m == 16) ? 8 : 16;
461 if (destType.getNumElements() != expectedOutLen)
462 return emitOpError(
"expected output vector of length ")
463 << expectedOutLen <<
" but got " << destType.getNumElements();
469 "for 16x16x128, sourceA must have 64 elements but got ")
473 "for 16x16x128, sourceB must have 64 elements but got ")
477 if (!isF4(aElemType) && !isF4(bElemType))
478 return emitOpError(
"32x16x128 only supports fp4 element types");
482 "for 32x16x128, sourceA must have 128 elements but got ")
486 "for 32x16x128, sourceB must have 64 elements but got ")
491 if (getAFirstScaleLane() != 0)
492 return emitOpError(
"for 32x16x128, a_first_scale_lane must be 0");
496 auto scaleAType = cast<VectorType>(getScaleA().
getType());
497 auto scaleBType = cast<VectorType>(getScaleB().
getType());
498 Type scaleAElemType = scaleAType.getElementType();
499 Type scaleBElemType = scaleBType.getElementType();
502 if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
504 "scale operands must have f8 element types (E8M0FNU or E4M3FN)");
507 if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
511 if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
512 isF4(bElemType) && isE4M3(scaleBElemType))
516 if (isF4(aElemType) && isE4M3(scaleAElemType) &&
517 (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
521 if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
522 isE4M3(scaleBElemType))
526 return emitOpError(
"invalid combination of matrix and scale types: ")
527 <<
"sourceA=" << aElemType <<
", scaleA=" << scaleAElemType
528 <<
", sourceB=" << bElemType <<
", scaleB=" << scaleBElemType;
534LogicalResult MFMAOp::verify() {
535 constexpr uint32_t waveSize = 64;
538 Type sourceType = getSourceA().getType();
539 Type destType = getDestC().getType();
541 Type sourceElem = sourceType, destElem = destType;
542 uint32_t sourceLen = 1, destLen = 1;
543 if (
auto sourceVector = dyn_cast<VectorType>(sourceType)) {
544 sourceLen = sourceVector.getNumElements();
545 sourceElem = sourceVector.getElementType();
547 if (
auto destVector = dyn_cast<VectorType>(destType)) {
548 destLen = destVector.getNumElements();
549 destElem = destVector.getElementType();
552 Type sourceBType = getSourceB().getType();
555 Type sourceBElem = sourceBType;
556 if (
auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
557 sourceBLen = sourceBVector.getNumElements();
558 sourceBElem = sourceBVector.getElementType();
562 return emitOpError(
"expected both source operands to have small-float "
563 "elements if one does");
564 if (sourceLen != sourceBLen)
566 "expected both small-float source vectors to have the same length");
568 if (sourceType != sourceBType)
569 return emitOpError(
"expected both non-small-float source operand types "
575 sourceElem =
b.getI8Type();
579 sourceElem =
b.getI8Type();
582 int64_t numSourceElems = (
getM() * getK() * getBlocks()) / waveSize;
583 if (sourceLen != numSourceElems)
584 return emitOpError(
"expected " + Twine(numSourceElems) +
585 " source values for this operation but got " +
589 if (destLen != numDestElems)
590 return emitOpError(
"expected " + Twine(numDestElems) +
591 " result values for this operation but got " +
594 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
596 "double-precision ops do not support permuting lanes of B");
597 if (destElem.isF64() && getCbsz() != 0)
599 "double-precision ops do not support permuting lanes of A");
600 if (getAbid() >= (1u << getCbsz()))
602 "block ID for permuting A (abid) must be below 2 ** cbsz");
604 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
606 "negation flags only available for double-precision operations");
615LogicalResult SparseMFMAOp::verify() {
616 constexpr uint32_t waveSize = 64;
618 auto sparseType = cast<VectorType>(getSourceA().
getType());
619 auto denseType = cast<VectorType>(getSourceB().
getType());
620 auto destType = cast<VectorType>(getDestC().
getType());
622 Type sparseElem = sparseType.getElementType();
623 Type denseElem = denseType.getElementType();
624 int64_t sparseLen = sparseType.getNumElements();
625 int64_t denseLen = denseType.getNumElements();
626 int64_t destLen = destType.getNumElements();
628 if (denseLen != 2 * sparseLen)
629 return emitOpError(
"expected dense source operand to have exactly double "
630 "the number of elements of the sparse source operand");
636 if (!bothFloat8 && sparseElem != denseElem)
638 "expected source operands to have the same element type");
645 uint32_t m =
getM(), k = getK();
648 !is8BitSource && ((m == 16 && k == 32) || (m == 32 && k == 16));
650 is8BitSource && ((m == 16 && k == 128) || (m == 32 && k == 64));
659 "CBSZ must be 0 for this variant (field is ignored by hardware)");
662 "ABID must be 0 for this variant (field is ignored by hardware)");
663 }
else if (getCbsz() == 0) {
664 unsigned maxAbid = is16BitGfx942 ? 3u : 1u;
665 if (getAbid() > maxAbid)
667 << maxAbid <<
"] for this variant";
670 Type sparseIdxType = getSparseIdx().getType();
673 return emitOpError(
"expected i32 sparse indices for this variant "
674 "(no internal set structure), but got ")
677 unsigned expectedIdxElems = is16BitGfx942 ? 4 : 2;
678 unsigned expectedIdxBits = is16BitGfx942 ? 8 : 16;
679 auto vecType = dyn_cast<VectorType>(sparseIdxType);
680 if (!vecType || vecType.getNumElements() != expectedIdxElems ||
681 !vecType.getElementType().isInteger(expectedIdxBits))
683 << expectedIdxElems <<
"xi" << expectedIdxBits
684 <<
"> sparse indices for this variant, but got " << sparseIdxType;
687 int64_t expectedSourceElems = (
getM() * getK()) / waveSize;
688 if (denseLen != expectedSourceElems)
689 return emitOpError(
"expected " + Twine(expectedSourceElems) +
690 " source values for this operation but got " +
694 if (destLen != expectedDestElems)
695 return emitOpError(
"expected " + Twine(expectedDestElems) +
696 " result values for this operation but got " +
706LogicalResult SparseWMMAOp::verify() {
707 auto sparseType = cast<VectorType>(getSourceA().
getType());
708 auto denseType = cast<VectorType>(getSourceB().
getType());
709 auto destType = cast<VectorType>(getDestC().
getType());
711 Type sparseElem = sparseType.getElementType();
712 Type denseElem = denseType.getElementType();
713 Type destElem = destType.getElementType();
714 int64_t sparseLen = sparseType.getNumElements();
715 int64_t denseLen = denseType.getNumElements();
716 int64_t destLen = destType.getNumElements();
718 uint32_t m =
getM(), n =
getN(), k = getK();
719 if ((m != 16) || (n != 16))
720 return emitOpError(
"expected MxN to be exactly 16x16");
722 const bool isWavesize64 = getWave64();
724 const bool isEqualLengthAllowed = isWavesize64 && isInt4Input && k == 32;
726 if ((denseLen != 2 * sparseLen) && !isEqualLengthAllowed)
727 return emitOpError(
"expected dense source operand to have exactly double "
728 "the number of elements of the sparse source operand");
730 if (isEqualLengthAllowed && (denseLen != sparseLen))
731 return emitOpError(
"expected dense source operand to have exactly the "
732 "same the number of elements");
736 return emitOpError(
"source operand and destination operands must all be "
737 "either integer or float types");
743 return emitOpError(
"source operand and destination operands must all be "
744 "either integer or float types");
752 if (!bothFloat8 && sparseElem != denseElem)
754 "expected source operands to have the same element type");
756 const int64_t waveSize = isWavesize64 ? 64 : 32;
758 int64_t expectedSourceElems = (
getM() * getK()) / waveSize;
759 if (denseLen != expectedSourceElems)
760 return emitOpError(
"expected " + Twine(expectedSourceElems) +
761 " source values for this operation but got " +
765 if (destLen != expectedDestElems)
766 return emitOpError(
"expected " + Twine(expectedDestElems) +
767 " result values for this operation but got " +
776LogicalResult DotOp::verify() {
777 Type aElem = cast<VectorType>(getSourceA().
getType()).getElementType();
778 Type bElem = cast<VectorType>(getSourceB().
getType()).getElementType();
779 Type dest = getDestC().getType();
781 bool aIsFloat8 = aElem.
isFloat(8);
782 bool bIsFloat8 = bElem.
isFloat(8);
783 bool aIsInteger = isa<IntegerType>(aElem);
785 bool bothFloat8 = aIsFloat8 && bIsFloat8;
786 if (!bothFloat8 && aElem != bElem)
788 "expected source operands to have the same element type");
792 return emitOpError(
"expected f32 or f16 accumulator for f16 sources");
793 }
else if (aElem.
isBF16()) {
795 return emitOpError(
"expected f32 or bf16 accumulator for bf16 sources");
796 }
else if (aIsInteger) {
798 return emitOpError(
"expected i32 accumulator for integer sources");
799 }
else if (aIsFloat8) {
801 return emitOpError(
"expected f32 accumulator for fp8 sources");
804 if ((getUnsignedA() || getUnsignedB()) && !aIsInteger)
806 "unsignedA/unsignedB are only valid for integer source types");
808 if (aElem.
isInteger(16) && getUnsignedA() != getUnsignedB())
810 "mixed-sign dot is not supported for 16-bit integer sources");
813 bool noClamp = (aElem.
isF16() && dest.
isF16()) ||
817 "clamp is not supported for this (source, accumulator) combination");
826LogicalResult DPPOp::verify() {
827 DPPPerm kind = getKind();
832 case DPPPerm::quad_perm: {
833 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
834 if (!quadPermAttr || quadPermAttr.size() != 4) {
835 return emitOpError(
"quad_perm attribute must have exactly 4 elements");
837 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
838 int32_t num = elem.getInt();
839 if (num < 0 || num > 3) {
841 "Each element of quad_perm must be in the range [0, 3]");
846 case DPPPerm::row_shl:
847 case DPPPerm::row_shr:
848 case DPPPerm::row_ror: {
850 return emitOpError(
"Attribute '" + Twine(stringifyDPPPerm(kind)) +
851 "' value not specified");
853 if (
auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
854 uint32_t attrValue = intAttr.getInt();
855 if (attrValue < 1 || attrValue > 15) {
856 return emitOpError(
"Attribute value must be between 1 and 15");
861 case DPPPerm::wave_shl:
862 case DPPPerm::wave_shr:
863 case DPPPerm::wave_rol:
864 case DPPPerm::wave_ror:
865 case DPPPerm::row_mirror:
866 case DPPPerm::row_half_mirror:
867 case DPPPerm::row_bcast_15:
868 case DPPPerm::row_bcast_31: {
869 if (permArgument && !isa<UnitAttr>(permArgument)) {
870 return emitOpError(
"Expected unit attribute for permArgument, but found "
871 "non-trivial argument");
882LogicalResult PermlaneSwapOp::verify() {
883 unsigned rowLength = getRowLength();
885 if (rowLength != 16 && rowLength != 32)
886 return emitOpError(
"row_length attribute must either be 16 or 32.");
894 if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
913struct FuseMemoryCounterWaitOp final :
OpRewritePattern<MemoryCounterWaitOp> {
916 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
917 PatternRewriter &rewriter)
const override {
918 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
922 auto setters = {&MemoryCounterWaitOp::setLoad,
923 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
924 &MemoryCounterWaitOp::setExp,
925 &MemoryCounterWaitOp::setTensor};
926 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
928 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
929 next.getExp(), next.getTensor()};
931 for (
auto [setter,
lhs,
rhs] :
932 llvm::zip_equal(setters, lhsVals, rhsVals)) {
934 (op.*setter)(std::min(*
lhs, *
rhs));
948void MemoryCounterWaitOp::getCanonicalizationPatterns(
950 results.
add<FuseMemoryCounterWaitOp>(context);
957LogicalResult GatherToLDSOp::verify() {
958 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
959 MemRefType dstType = cast<MemRefType>(getDst().
getType());
964 getDstIndices().size())))
967 if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
968 return emitOpError(
"destination type inner most dim must be contiguous");
970 auto elemType = srcType.getElementType();
972 if (elemType != dstType.getElementType())
973 return emitOpError(
"source and destination element types must match");
976 auto transferType = getTransferType();
978 if (
auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
979 transferSize = vectorTransfer.getNumElements() *
980 vectorTransfer.getElementTypeBitWidth();
982 transferSize = transferType.getIntOrFloatBitWidth();
984 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
986 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
991 "source memory address space must be global or fat raw buffer");
994 return emitOpError(
"destination memory address space must be Workgroup");
1005 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
1006 PatternRewriter &rewriter)
const override {
1007 bool modified =
false;
1008 auto foldCast = [&](OpOperand &operand) {
1009 if (
auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
1010 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
1012 [&] { operand.assign(castOp.getSource()); });
1018 foldCast(gatherOp.getSrcMutable());
1019 foldCast(gatherOp.getDstMutable());
1028 results.
add<FoldGatherToLDSOfCast>(context);
1035LogicalResult GlobalLoadAsyncToLDSOp::verify() {
1036 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
1037 MemRefType dstType = cast<MemRefType>(getDst().
getType());
1042 getDstIndices().size())))
1045 if (srcType.getElementType() != dstType.getElementType())
1046 return emitOpError(
"source and destination element types must match");
1048 Type transferType = getTransferType();
1050 if (
auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
1051 transferSize = vectorTransfer.getNumElements() *
1052 vectorTransfer.getElementTypeBitWidth();
1056 if (!llvm::is_contained({8, 32, 64, 128}, transferSize))
1057 return emitOpError(
"transfer type size must be 8, 32, 64, or 128 bits");
1060 return emitOpError(
"source memory address space must be global");
1063 return emitOpError(
"destination memory address space must be Workgroup");
1072LogicalResult TransposeLoadOp::verify() {
1073 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
1080 return emitOpError(
"source memory address space must be Workgroup");
1082 auto transferType = cast<VectorType>(
getType());
1083 size_t numElements = transferType.getNumElements();
1084 size_t elementTypeSize =
1088 const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
1095 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
1096 if (validNumElems == kValidLoadSizeMap.end())
1097 return emitOpError(
"Unsupported element type size for transpose load: ")
1098 << elementTypeSize <<
" bits";
1100 if (numElements != validNumElems->second)
1102 "Transferring type size mismatch: expected num of elements: ")
1103 << validNumElems->second;
1112LogicalResult GlobalTransposeLoadOp::verify() {
1113 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
1120 return emitOpError(
"source memory address space must be Global");
1122 auto resultType = cast<VectorType>(
getType());
1123 size_t numElements = resultType.getNumElements();
1124 size_t elementTypeSize = resultType.getElementType().getIntOrFloatBitWidth();
1128 static const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
1135 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
1136 if (validNumElems == kValidLoadSizeMap.end())
1138 "unsupported element type size for global transpose load: ")
1139 << elementTypeSize <<
" bits";
1141 if (numElements != validNumElems->second)
1143 "transferring type size mismatch: expected num of elements: ")
1144 << validNumElems->second;
1153template <
typename BaseOp>
1155 auto ldsType = cast<MemRefType>(op.getLds().getType());
1156 auto globalType = cast<MemRefType>(op.getGlobal().getType());
1158 op.getGlobalIndices().size())) ||
1163 return op.emitOpError(
1164 "lds memref must have workgroup address space attribute.");
1166 return op.emitOpError(
1167 "global memref must have global address space attribute.");
1169 Type elementType = ldsType.getElementType();
1172 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1173 return op.emitOpError(
1174 "element type must be 1, 2, 4, or 8 bytes long but type was ")
1175 << width <<
" bits long.";
1179LogicalResult MakeDmaBaseOp::verify() {
return verifyBase(*
this); }
// Fragment of a static checker (its opening signature line was lost in
// extraction): validates the element width and the gather-index type used by
// gather-DMA base construction.
1187 Type elementType,
Type indexType) {
// Element type must be 1, 2, 4, or 8 bytes wide.
1189 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1191 <<
"element type must be 1, 2, 4, or 8 bytes wide but type "
1192 << elementType <<
" is " << width / 8 <<
" bytes wide.";
// NOTE(review): the local names look swapped -- `i16` is constructed with
// width 32 and `i32` with width 16. The containment check below is
// behavior-neutral to the swap (the set holds both types either way), but
// the names should be corrected -- TODO confirm against the original file.
1194 Type i16 = IntegerType::get(ctx, 32);
1195 Type i32 = IntegerType::get(ctx, 16);
// Gather indices must be i16 or i32.
1196 if (!llvm::is_contained({i16, i32}, indexType))
1197 return emitError() <<
"index type must be i16 or i32 but index type is "
1198 << indexType <<
".";
1202LogicalResult MakeGatherDmaBaseOp::verify() {
return verifyBase(*
this); }
1208template <
typename DescriptorOp>
1212 if (globalStaticStrides.empty())
1213 return op.emitOpError(
"strides must not be empty.");
1214 if (globalStaticStrides.back() != 1)
1215 return op.emitOpError(
"strides for the innermost dimension must be 1.");
1218 size_t rank = globalStaticSizes.size();
1220 return op.emitOpError(
"tensor and tile must be at most of rank 5.");
1221 if (rank != globalStaticStrides.size())
1222 return op.emitOpError(
"strides and sizes must have same rank.");
1225 if (rank != sharedStaticSizes.size())
1226 return op.emitOpError(
"tensor must have same rank as tile.");
1228 unsigned elementTypeWidth = op.getElementTypeWidth();
1229 if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
1230 return op.emitOpError(
1231 "element type width must be 1, 2, 4 or 8 bytes, but was ")
1232 << elementTypeWidth <<
" bits long";
1234 if (!op.getAtomicBarrierAddress() && !op.getAtomicBarrierIndices().empty())
1235 return op.emitOpError(
1236 "atomic barrier indices require an atomic barrier address");
1238 if (
Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
1239 auto atomicBarrierAddressType =
1240 cast<MemRefType>(atomicBarrierAddress.getType());
1241 if (failed(
verifyIndexCount(op,
"atomic barrier", atomicBarrierAddressType,
1242 op.getAtomicBarrierIndices().size())))
1248 return op.emitOpError(
"atomic barrier address must be in LDS.");
1251 if (op.getEarlyTimeout() && !op.getWorkgroupMask())
1252 return op.emitOpError(
1253 "early timeout does not apply when workgroup_mask is not set.");
1257template <
typename DescriptorOp,
typename FoldAdaptor>
1278 op.setGlobalStaticSizes(staticGlobalSizes);
1279 op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
1282 staticGlobalStrides);
1283 op.setGlobalStaticStrides(staticGlobalStrides);
1284 op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
1288 op.setSharedStaticSizes(staticSharedSizes);
1289 op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
1290 return op.getResult();
1293LogicalResult MakeDmaDescriptorOp::verify() {
1297OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1305LogicalResult MakeGatherDmaDescriptorOp::verify() {
1307 size_t rank = globalStaticSizes.size();
1310 "tensor and tile must be at most of rank two in gather mode.");
1312 Type elementType = cast<VectorType>(
indices.getType()).getElementType();
1314 return emitOpError(
"indices' element type must match base's element type.");
1319OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1333 LogicalResult matchAndRewrite(ScaledMFMAOp op,
1334 PatternRewriter &rewriter)
const override {
1335 Location loc = op.getLoc();
1336 auto setOpsel = [&op](
unsigned idx, int64_t val) {
1339 op.setScalesIdxA(val);
1342 op.setScalesIdxB(val);
1366 for (
auto opIdx : std::array<int64_t, 2>({3, 4})) {
1367 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
1370 "defining op not a vector.insert");
1373 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
1375 op,
"scaled mfma operand already packed");
1379 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
1382 "defining op not a vector.extract");
1385 Value scaleSrc = extractOp.getOperand(0);
1386 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.
getType());
1387 if (!scaleSrcType) {
1393 if (!scaleSrcType.hasStaticShape()) {
1395 "dynamic dims not yet supported");
1398 int64_t numElements = scaleSrcType.getNumElements();
1399 if (numElements < 4) {
1401 op,
"do not pack if # of scales less than four");
1405 auto extractedPos = llvm::to_vector_of<int64_t>(
1406 llvm::reverse(extractOp.getStaticPosition()));
1407 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
1408 int64_t scaleSrcRank = scaleSrcType.getRank();
1409 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
1410 for (int64_t i = 1; i < scaleSrcRank; ++i) {
1411 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
1413 int64_t idx =
linearize(extractedPos, extractSizes);
1425 int64_t offset = idx - (idx % 4);
1426 int64_t opsel = idx - offset;
1429 if (numElements - offset < size) {
1430 opsel = size - (numElements - idx);
1431 offset = numElements - 4l;
1433 Type scaleSrcElemType = scaleSrcType.getElementType();
1435 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
1437 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
1438 auto extract = vector::ExtractStridedSliceOp::create(
1439 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
1440 ArrayRef{int64_t(1)});
1442 op->setOperand(opIdx, extract);
1443 setOpsel(opIdx, opsel);
1453 results.
add<PackScales>(context);
1460template <
typename T>
1462 MemRefType memrefType = llvm::cast<MemRefType>(op.getBase().getType());
1468 return op.emitOpError(
"barrier must be in workgroup (LDS) memory");
1473LogicalResult DsBarrierInitOp::verify() {
1477LogicalResult DsBarrierPollStateOp::verify() {
1481LogicalResult DsAsyncBarrierArriveOp::verify() {
1485LogicalResult DsBarrierArriveOp::verify() {
1493LogicalResult GlobalPrefetchOp::verify() {
1494 auto src = cast<MemRefType>(getSrc().
getType());
1499 Attribute memSpace = src.getMemorySpace();
1501 return this->
emitOpError(
"the source must have address space attribute");
1503 return this->
emitOpError(
"the source must reside in global address space");
1505 const LoadTemporalHint temporalHint = getTemporalHint();
1506 const Scope scope = getCacheScope();
1507 const bool isSpeculative = getSpeculative();
1510 if (isSpeculative && scope == Scope::WGP)
1512 "does not support speculative prefetch in WGP scope");
1520 if (llvm::is_contained({LoadTemporalHint::NT, LoadTemporalHint::LU},
1522 return this->
emitOpError(
"does not support NT and LU modes");
1524 if (llvm::is_contained({LoadTemporalHint::NT_RT, LoadTemporalHint::RT_NT,
1525 LoadTemporalHint::NT_HT},
1528 return this->
emitOpError(
"operates only in the speculative mode");
1533#define GET_OP_CLASSES
1534#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyDescriptorOp(DescriptorOp op)
static LogicalResult verifyRawBufferOp(T &op)
static LogicalResult verifyDsBarrierOpCommon(T &op)
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor)
static bool hasGlobalMemorySpace(Attribute memorySpace)
static bool hasWorkgroupMemorySpace(Attribute memorySpace)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless resetOffset is true, in which case the offset is reset to zero.
static LogicalResult verifyIndexCount(OpTy op, StringRef indexName, MemRefType memrefType, int64_t numIndices)
Verifies that the number of indices matches the rank of the indexed memref, emitting an op error mentioning indexName ("expected <rank> <indexName> indices, got <numIndices>") on mismatch.
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op, PatternRewriter &rewriter)
Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
static bool hasFatRawBufferMemorySpace(Attribute memorySpace)
static LogicalResult verifyBase(BaseOp op)
static bool staticallyOutOfBounds(OpType op)
static std::optional< uint32_t > getConstantUint32(Value v)
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class helps build Operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is an float type (with the specified width).
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, IntegerAttr &n, IntegerAttr &k)
Parser for the custom<MNKDimensionList> custom assembly format used by WMMAOp.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
llvm::function_ref< Fn > function_ref
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...