31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/SmallVector.h"
44template <
typename OpTy>
46 MemRefType memrefType,
48 int64_t rank = memrefType.getRank();
49 if (rank != numIndices)
50 return op.emitOpError(
"expected ")
51 << rank <<
" " << indexName <<
" indices, got " << numIndices;
58LogicalResult PackedTrunc2xFp8Op::verify() {
59 if (getExisting() && getExisting().
getType() != getResult().
getType())
60 return emitOpError(
"existing values must have same type as result");
64LogicalResult PackedStochRoundFp8Op::verify() {
65 if (getExisting() && getExisting().
getType() != getResult().
getType())
66 return emitOpError(
"existing values must have same type as result");
73LogicalResult PackedScaledTruncOp::verify() {
74 if (getExisting() && getExisting().
getType() != getResult().
getType())
75 return emitOpError(
"existing values must have same type as result");
92 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
93 MemRefLayoutAttrInterface layout = source.getLayout();
94 if (resetOffset && !layout.isIdentity()) {
95 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
98 MemRefLayoutAttrInterface newLayout =
99 StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
104 if (source.hasStaticShape()) {
106 }
else if (source.getRank() <= 1) {
109 if (stridesIfIdentity == stridedLayout.getStrides()) {
110 newLayout = AffineMapAttr::get(
115 return (MemRefType)(mb);
118LogicalResult FatRawBufferCastOp::inferReturnTypes(
122 Adaptor adaptor(operands, attributes, properties, regions);
124 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
127 FailureOr<MemRefType> resultType =
135FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(
OpBuilder &builder,
138 assert(resultIndex == 0 &&
"FatRawBufferCastOp has a single result");
142LogicalResult FatRawBufferCastOp::verify() {
143 FailureOr<MemRefType> expectedResultType =
145 if (
failed(expectedResultType))
147 << getSource().getType() <<
" can't have its offset reset";
148 if (getResult().
getType() != *expectedResultType)
150 << *expectedResultType <<
" but got " << getResult().getType();
159 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
164 return op.emitOpError(
165 "buffer ops must operate on a memref in global memory");
166 if (!bufferType.hasRank())
167 return op.emitOpError(
168 "cannot meaningfully buffer_store to an unranked memref");
176LogicalResult RawBufferAtomicFaddOp::verify() {
180LogicalResult RawBufferAtomicFmaxOp::verify() {
184LogicalResult RawBufferAtomicSmaxOp::verify() {
188LogicalResult RawBufferAtomicUminOp::verify() {
192LogicalResult RawBufferAtomicCmpswapOp::verify() {
201 return cst.getZExtValue();
205template <
typename OpType>
207 if (!op.getBoundsCheck())
209 MemRefType bufferType = op.getMemref().getType();
210 if (!bufferType.hasStaticShape())
214 if (failed(bufferType.getStridesAndOffset(strides, offset)))
217 if (op.getSgprOffset()) {
223 if (strides.size() != op.getIndices().size())
226 for (
auto pair : llvm::zip(strides, op.getIndices())) {
227 int64_t stride = std::get<0>(pair);
228 Value idx = std::get<1>(pair);
232 indexVal += stride * *idxVal;
235 if (
result > std::numeric_limits<uint32_t>::max())
238 return result >= bufferType.getNumElements();
242template <
typename OpType>
243struct RemoveStaticallyOobBufferLoads final :
public OpRewritePattern<OpType> {
244 using OpRewritePattern<OpType>::OpRewritePattern;
246 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw)
const override {
249 Type loadType = op.getResult().getType();
256template <
typename OpType>
257struct RemoveStaticallyOobBufferWrites final :
public OpRewritePattern<OpType> {
258 using OpRewritePattern<OpType>::OpRewritePattern;
260 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw)
const override {
272 results.
add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
277 results.
add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
280void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
282 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicFaddOp>>(context);
285void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
287 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicFmaxOp>>(context);
290void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
292 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicSmaxOp>>(context);
295void RawBufferAtomicUminOp::getCanonicalizationPatterns(
297 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicUminOp>>(context);
300void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
302 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
309LogicalResult ScaledExtPackedMatrixOp::verify() {
311 assert(llvm::is_contained({16, 32}, blockSize) &&
"invalid block size");
313 int firstScaleByte = getFirstScaleByte();
314 int firstScaleLane = getFirstScaleLane();
315 auto sourceType = cast<VectorType>(getSource().
getType());
316 Type elementType = sourceType.getElementType();
317 auto floatType = cast<FloatType>(elementType);
318 unsigned bitWidth = floatType.getWidth();
322 const bool is_fp8 = bitWidth == 8;
323 const bool is_block_16 = blockSize == 16;
327 if (!llvm::is_contained({0, 1}, firstScaleByte)) {
328 return emitOpError(
"blockSize of 16 can only have firstScaleByte be 0 "
329 "or 1 for f4 and f6.");
332 if (!llvm::is_contained({0, 2}, firstScaleByte)) {
333 return emitOpError(
"blockSize of 32 can only have firstScaleByte be 0 "
334 "or 2 for f4 and f6.");
339 bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
340 ((firstScaleLane == 16) && (firstScaleByte == 2));
342 return emitOpError(
"blockSize of 16 can only have (firstScaleLane, "
343 "firstScaleByte) be (0, 0) or (16, 2) for f8.");
356 IntegerAttr &m, IntegerAttr &n,
361 if (dimensions.size() != 3)
363 <<
"expected 3 dimensions in MNK dimension list";
371LogicalResult WMMAOp::verify() {
372 auto sourceAType = cast<VectorType>(getSourceA().
getType());
373 auto sourceBType = cast<VectorType>(getSourceB().
getType());
374 auto destType = cast<VectorType>(getDestC().
getType());
376 Type sourceAElemType = sourceAType.getElementType();
377 Type sourceBElemType = sourceBType.getElementType();
378 if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
379 return emitOpError(
"source vectors have different lengths: ")
380 << sourceAType <<
" vs. " << sourceBType;
383 bool isDestFloat = destType.getElementType().
isFloat();
384 bool isSrcFloat = sourceAElemType.
isFloat();
386 if (isDestFloat && !isSrcFloat)
387 return emitOpError(
"expected float sources with float destination");
388 if (!isDestFloat && isSrcFloat)
389 return emitOpError(
"expected int sources with int destination");
391 if (!sourceAElemType.
isFloat(8) && sourceAElemType != sourceBElemType) {
393 "source element types must match (except for fp8/bf8) but have ")
394 << sourceAType <<
" and " << sourceBType;
399 return emitOpError(
"clamp flag is not supported for float types");
400 if (getUnsignedA() || getUnsignedB())
401 return emitOpError(
"unsigned flags are not supported for float types");
410LogicalResult ScaledWMMAOp::verify() {
412 auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
413 auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
414 auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
415 auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
416 auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
417 auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
419 auto sourceAType = cast<VectorType>(getSourceA().
getType());
420 auto sourceBType = cast<VectorType>(getSourceB().
getType());
421 auto destType = cast<VectorType>(getDestC().
getType());
424 Type aElemType = sourceAType.getElementType();
425 Type bElemType = sourceBType.getElementType();
429 int64_t aLen = sourceAType.getNumElements();
430 int64_t bLen = sourceBType.getNumElements();
431 int64_t expectedOutLen = (m == 16) ? 8 : 16;
433 if (destType.getNumElements() != expectedOutLen)
434 return emitOpError(
"expected output vector of length ")
435 << expectedOutLen <<
" but got " << destType.getNumElements();
441 "for 16x16x128, sourceA must have 64 elements but got ")
445 "for 16x16x128, sourceB must have 64 elements but got ")
449 if (!isF4(aElemType) && !isF4(bElemType))
450 return emitOpError(
"32x16x128 only supports fp4 element types");
454 "for 32x16x128, sourceA must have 128 elements but got ")
458 "for 32x16x128, sourceB must have 64 elements but got ")
463 if (getAFirstScaleLane() != 0)
464 return emitOpError(
"for 32x16x128, a_first_scale_lane must be 0");
468 auto scaleAType = cast<VectorType>(getScaleA().
getType());
469 auto scaleBType = cast<VectorType>(getScaleB().
getType());
470 Type scaleAElemType = scaleAType.getElementType();
471 Type scaleBElemType = scaleBType.getElementType();
474 if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
476 "scale operands must have f8 element types (E8M0FNU or E4M3FN)");
479 if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
483 if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
484 isF4(bElemType) && isE4M3(scaleBElemType))
488 if (isF4(aElemType) && isE4M3(scaleAElemType) &&
489 (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
493 if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
494 isE4M3(scaleBElemType))
498 return emitOpError(
"invalid combination of matrix and scale types: ")
499 <<
"sourceA=" << aElemType <<
", scaleA=" << scaleAElemType
500 <<
", sourceB=" << bElemType <<
", scaleB=" << scaleBElemType;
506LogicalResult MFMAOp::verify() {
507 constexpr uint32_t waveSize = 64;
510 Type sourceType = getSourceA().getType();
511 Type destType = getDestC().getType();
513 Type sourceElem = sourceType, destElem = destType;
514 uint32_t sourceLen = 1, destLen = 1;
515 if (
auto sourceVector = dyn_cast<VectorType>(sourceType)) {
516 sourceLen = sourceVector.getNumElements();
517 sourceElem = sourceVector.getElementType();
519 if (
auto destVector = dyn_cast<VectorType>(destType)) {
520 destLen = destVector.getNumElements();
521 destElem = destVector.getElementType();
524 Type sourceBType = getSourceB().getType();
527 Type sourceBElem = sourceBType;
528 if (
auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
529 sourceBLen = sourceBVector.getNumElements();
530 sourceBElem = sourceBVector.getElementType();
534 return emitOpError(
"expected both source operands to have small-float "
535 "elements if one does");
536 if (sourceLen != sourceBLen)
538 "expected both small-float source vectors to have the same length");
540 if (sourceType != sourceBType)
541 return emitOpError(
"expected both non-small-float source operand types "
547 sourceElem =
b.getI8Type();
551 sourceElem =
b.getI8Type();
554 int64_t numSourceElems = (
getM() * getK() * getBlocks()) / waveSize;
555 if (sourceLen != numSourceElems)
556 return emitOpError(
"expected " + Twine(numSourceElems) +
557 " source values for this operation but got " +
561 if (destLen != numDestElems)
562 return emitOpError(
"expected " + Twine(numDestElems) +
563 " result values for this operation but got " +
566 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
568 "double-precision ops do not support permuting lanes of B");
569 if (destElem.isF64() && getCbsz() != 0)
571 "double-precision ops do not support permuting lanes of A");
572 if (getAbid() >= (1u << getCbsz()))
574 "block ID for permuting A (abid) must be below 2 ** cbsz");
576 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
578 "negation flags only available for double-precision operations");
587LogicalResult SparseMFMAOp::verify() {
588 constexpr uint32_t waveSize = 64;
590 auto sparseType = cast<VectorType>(getSourceA().
getType());
591 auto denseType = cast<VectorType>(getSourceB().
getType());
592 auto destType = cast<VectorType>(getDestC().
getType());
594 Type sparseElem = sparseType.getElementType();
595 Type denseElem = denseType.getElementType();
596 int64_t sparseLen = sparseType.getNumElements();
597 int64_t denseLen = denseType.getNumElements();
598 int64_t destLen = destType.getNumElements();
600 if (denseLen != 2 * sparseLen)
601 return emitOpError(
"expected dense source operand to have exactly double "
602 "the number of elements of the sparse source operand");
608 if (!bothFloat8 && sparseElem != denseElem)
610 "expected source operands to have the same element type");
617 uint32_t m =
getM(), k = getK();
620 !is8BitSource && ((m == 16 && k == 32) || (m == 32 && k == 16));
622 is8BitSource && ((m == 16 && k == 128) || (m == 32 && k == 64));
631 "CBSZ must be 0 for this variant (field is ignored by hardware)");
634 "ABID must be 0 for this variant (field is ignored by hardware)");
635 }
else if (getCbsz() == 0) {
636 unsigned maxAbid = is16BitGfx942 ? 3u : 1u;
637 if (getAbid() > maxAbid)
639 << maxAbid <<
"] for this variant";
642 Type sparseIdxType = getSparseIdx().getType();
645 return emitOpError(
"expected i32 sparse indices for this variant "
646 "(no internal set structure), but got ")
649 unsigned expectedIdxElems = is16BitGfx942 ? 4 : 2;
650 unsigned expectedIdxBits = is16BitGfx942 ? 8 : 16;
651 auto vecType = dyn_cast<VectorType>(sparseIdxType);
652 if (!vecType || vecType.getNumElements() != expectedIdxElems ||
653 !vecType.getElementType().isInteger(expectedIdxBits))
655 << expectedIdxElems <<
"xi" << expectedIdxBits
656 <<
"> sparse indices for this variant, but got " << sparseIdxType;
659 int64_t expectedSourceElems = (
getM() * getK()) / waveSize;
660 if (denseLen != expectedSourceElems)
661 return emitOpError(
"expected " + Twine(expectedSourceElems) +
662 " source values for this operation but got " +
666 if (destLen != expectedDestElems)
667 return emitOpError(
"expected " + Twine(expectedDestElems) +
668 " result values for this operation but got " +
678LogicalResult SparseWMMAOp::verify() {
679 auto sparseType = cast<VectorType>(getSourceA().
getType());
680 auto denseType = cast<VectorType>(getSourceB().
getType());
681 auto destType = cast<VectorType>(getDestC().
getType());
683 Type sparseElem = sparseType.getElementType();
684 Type denseElem = denseType.getElementType();
685 Type destElem = destType.getElementType();
686 int64_t sparseLen = sparseType.getNumElements();
687 int64_t denseLen = denseType.getNumElements();
688 int64_t destLen = destType.getNumElements();
690 uint32_t m =
getM(), n =
getN(), k = getK();
691 if ((m != 16) || (n != 16))
692 return emitOpError(
"expected MxN to be exactly 16x16");
694 const bool isWavesize64 = getWave64();
696 const bool isEqualLengthAllowed = isWavesize64 && isInt4Input && k == 32;
698 if ((denseLen != 2 * sparseLen) && !isEqualLengthAllowed)
699 return emitOpError(
"expected dense source operand to have exactly double "
700 "the number of elements of the sparse source operand");
702 if (isEqualLengthAllowed && (denseLen != sparseLen))
703 return emitOpError(
"expected dense source operand to have exactly the "
704 "same the number of elements");
708 return emitOpError(
"source operand and destination operands must all be "
709 "either integer or float types");
715 return emitOpError(
"source operand and destination operands must all be "
716 "either integer or float types");
724 if (!bothFloat8 && sparseElem != denseElem)
726 "expected source operands to have the same element type");
728 const int64_t waveSize = isWavesize64 ? 64 : 32;
730 int64_t expectedSourceElems = (
getM() * getK()) / waveSize;
731 if (denseLen != expectedSourceElems)
732 return emitOpError(
"expected " + Twine(expectedSourceElems) +
733 " source values for this operation but got " +
737 if (destLen != expectedDestElems)
738 return emitOpError(
"expected " + Twine(expectedDestElems) +
739 " result values for this operation but got " +
748LogicalResult DotOp::verify() {
749 Type aElem = cast<VectorType>(getSourceA().
getType()).getElementType();
750 Type bElem = cast<VectorType>(getSourceB().
getType()).getElementType();
751 Type dest = getDestC().getType();
753 bool aIsFloat8 = aElem.
isFloat(8);
754 bool bIsFloat8 = bElem.
isFloat(8);
755 bool aIsInteger = isa<IntegerType>(aElem);
757 bool bothFloat8 = aIsFloat8 && bIsFloat8;
758 if (!bothFloat8 && aElem != bElem)
760 "expected source operands to have the same element type");
764 return emitOpError(
"expected f32 or f16 accumulator for f16 sources");
765 }
else if (aElem.
isBF16()) {
767 return emitOpError(
"expected f32 or bf16 accumulator for bf16 sources");
768 }
else if (aIsInteger) {
770 return emitOpError(
"expected i32 accumulator for integer sources");
771 }
else if (aIsFloat8) {
773 return emitOpError(
"expected f32 accumulator for fp8 sources");
776 if ((getUnsignedA() || getUnsignedB()) && !aIsInteger)
778 "unsignedA/unsignedB are only valid for integer source types");
780 if (aElem.
isInteger(16) && getUnsignedA() != getUnsignedB())
782 "mixed-sign dot is not supported for 16-bit integer sources");
785 bool noClamp = (aElem.
isF16() && dest.
isF16()) ||
789 "clamp is not supported for this (source, accumulator) combination");
798LogicalResult DPPOp::verify() {
799 DPPPerm kind = getKind();
804 case DPPPerm::quad_perm: {
805 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
806 if (!quadPermAttr || quadPermAttr.size() != 4) {
807 return emitOpError(
"quad_perm attribute must have exactly 4 elements");
809 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
810 int32_t num = elem.getInt();
811 if (num < 0 || num > 3) {
813 "Each element of quad_perm must be in the range [0, 3]");
818 case DPPPerm::row_shl:
819 case DPPPerm::row_shr:
820 case DPPPerm::row_ror: {
822 return emitOpError(
"Attribute '" + Twine(stringifyDPPPerm(kind)) +
823 "' value not specified");
825 if (
auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
826 uint32_t attrValue = intAttr.getInt();
827 if (attrValue < 1 || attrValue > 15) {
828 return emitOpError(
"Attribute value must be between 1 and 15");
833 case DPPPerm::wave_shl:
834 case DPPPerm::wave_shr:
835 case DPPPerm::wave_rol:
836 case DPPPerm::wave_ror:
837 case DPPPerm::row_mirror:
838 case DPPPerm::row_half_mirror:
839 case DPPPerm::row_bcast_15:
840 case DPPPerm::row_bcast_31: {
841 if (permArgument && !isa<UnitAttr>(permArgument)) {
842 return emitOpError(
"Expected unit attribute for permArgument, but found "
843 "non-trivial argument");
854LogicalResult PermlaneSwapOp::verify() {
855 unsigned rowLength = getRowLength();
857 if (rowLength != 16 && rowLength != 32)
858 return emitOpError(
"row_length attribute must either be 16 or 32.");
866 if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
885struct FuseMemoryCounterWaitOp final :
OpRewritePattern<MemoryCounterWaitOp> {
888 LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
889 PatternRewriter &rewriter)
const override {
890 auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
894 auto setters = {&MemoryCounterWaitOp::setLoad,
895 &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
896 &MemoryCounterWaitOp::setExp,
897 &MemoryCounterWaitOp::setTensor};
898 auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
900 auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
901 next.getExp(), next.getTensor()};
903 for (
auto [setter,
lhs,
rhs] :
904 llvm::zip_equal(setters, lhsVals, rhsVals)) {
906 (op.*setter)(std::min(*
lhs, *
rhs));
920void MemoryCounterWaitOp::getCanonicalizationPatterns(
922 results.
add<FuseMemoryCounterWaitOp>(context);
929LogicalResult GatherToLDSOp::verify() {
930 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
931 MemRefType dstType = cast<MemRefType>(getDst().
getType());
936 getDstIndices().size())))
939 if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
940 return emitOpError(
"destination type inner most dim must be contiguous");
942 auto elemType = srcType.getElementType();
944 if (elemType != dstType.getElementType())
945 return emitOpError(
"source and destination element types must match");
948 auto transferType = getTransferType();
950 if (
auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
951 transferSize = vectorTransfer.getNumElements() *
952 vectorTransfer.getElementTypeBitWidth();
954 transferSize = transferType.getIntOrFloatBitWidth();
956 if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
958 "Transfering type size must be 8, 16, 32, 96 or 128 bits");
963 "source memory address space must be global or fat raw buffer");
966 return emitOpError(
"destination memory address space must be Workgroup");
977 LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
978 PatternRewriter &rewriter)
const override {
979 bool modified =
false;
980 auto foldCast = [&](OpOperand &operand) {
981 if (
auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
982 if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
984 [&] { operand.assign(castOp.getSource()); });
990 foldCast(gatherOp.getSrcMutable());
991 foldCast(gatherOp.getDstMutable());
1000 results.
add<FoldGatherToLDSOfCast>(context);
1007LogicalResult GlobalLoadAsyncToLDSOp::verify() {
1008 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
1009 MemRefType dstType = cast<MemRefType>(getDst().
getType());
1014 getDstIndices().size())))
1017 if (srcType.getElementType() != dstType.getElementType())
1018 return emitOpError(
"source and destination element types must match");
1020 Type transferType = getTransferType();
1022 if (
auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
1023 transferSize = vectorTransfer.getNumElements() *
1024 vectorTransfer.getElementTypeBitWidth();
1028 if (!llvm::is_contained({8, 32, 64, 128}, transferSize))
1029 return emitOpError(
"transfer type size must be 8, 32, 64, or 128 bits");
1032 return emitOpError(
"source memory address space must be global");
1035 return emitOpError(
"destination memory address space must be Workgroup");
1043 Value mask = op.getMask();
1051 if (maskValue.isZero()) {
1060void GlobalLoadAsyncToLDSOp::getCanonicalizationPatterns(
1069LogicalResult TransposeLoadOp::verify() {
1070 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
1077 return emitOpError(
"source memory address space must be Workgroup");
1079 auto transferType = cast<VectorType>(
getType());
1080 size_t numElements = transferType.getNumElements();
1081 size_t elementTypeSize =
1084 auto emitNumElementsError = [&](StringRef expected) {
1086 "Transferring type size mismatch: expected num of elements: ")
1090 switch (elementTypeSize) {
1093 if (numElements != 16)
1094 return emitNumElementsError(
"16");
1097 if (numElements != 8)
1098 return emitNumElementsError(
"8");
1101 if (numElements != 4 && numElements != 8)
1102 return emitNumElementsError(
"4 or 8");
1105 return emitOpError(
"Unsupported element type size for transpose load: ")
1106 << elementTypeSize <<
" bits";
1116LogicalResult GlobalTransposeLoadOp::verify() {
1117 MemRefType srcType = cast<MemRefType>(getSrc().
getType());
1124 return emitOpError(
"source memory address space must be Global");
1126 auto resultType = cast<VectorType>(
getType());
1127 size_t numElements = resultType.getNumElements();
1128 size_t elementTypeSize = resultType.getElementType().getIntOrFloatBitWidth();
1132 static const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
1139 auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
1140 if (validNumElems == kValidLoadSizeMap.end())
1142 "unsupported element type size for global transpose load: ")
1143 << elementTypeSize <<
" bits";
1145 if (numElements != validNumElems->second)
1147 "transferring type size mismatch: expected num of elements: ")
1148 << validNumElems->second;
1157template <
typename BaseOp>
1159 auto ldsType = cast<MemRefType>(op.getLds().getType());
1160 auto globalType = cast<MemRefType>(op.getGlobal().getType());
1162 op.getGlobalIndices().size())) ||
1167 return op.emitOpError(
1168 "lds memref must have workgroup address space attribute.");
1170 return op.emitOpError(
1171 "global memref must have global address space attribute.");
1173 Type elementType = ldsType.getElementType();
1176 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1177 return op.emitOpError(
1178 "element type must be 1, 2, 4, or 8 bytes long but type was ")
1179 << width <<
" bits long.";
1183LogicalResult MakeDmaBaseOp::verify() {
return verifyBase(*
this); }
1191 Type elementType,
Type indexType) {
1193 if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
1195 <<
"element type must be 1, 2, 4, or 8 bytes wide but type "
1196 << elementType <<
" is " << width / 8 <<
" bytes wide.";
1198 Type i16 = IntegerType::get(ctx, 32);
1199 Type i32 = IntegerType::get(ctx, 16);
1200 if (!llvm::is_contained({i16, i32}, indexType))
1201 return emitError() <<
"index type must be i16 or i32 but index type is "
1202 << indexType <<
".";
1206LogicalResult MakeGatherDmaBaseOp::verify() {
return verifyBase(*
this); }
1212template <
typename DescriptorOp>
1216 if (globalStaticStrides.empty())
1217 return op.emitOpError(
"strides must not be empty.");
1218 if (globalStaticStrides.back() != 1)
1219 return op.emitOpError(
"strides for the innermost dimension must be 1.");
1222 size_t rank = globalStaticSizes.size();
1224 return op.emitOpError(
"tensor and tile must be at most of rank 5.");
1225 if (rank != globalStaticStrides.size())
1226 return op.emitOpError(
"strides and sizes must have same rank.");
1229 if (rank != sharedStaticSizes.size())
1230 return op.emitOpError(
"tensor must have same rank as tile.");
1232 unsigned elementTypeWidth = op.getElementTypeWidth();
1233 if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
1234 return op.emitOpError(
1235 "element type width must be 1, 2, 4 or 8 bytes, but was ")
1236 << elementTypeWidth <<
" bits long";
1238 if (!op.getAtomicBarrierAddress() && !op.getAtomicBarrierIndices().empty())
1239 return op.emitOpError(
1240 "atomic barrier indices require an atomic barrier address");
1242 if (
Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
1243 auto atomicBarrierAddressType =
1244 cast<MemRefType>(atomicBarrierAddress.getType());
1245 if (failed(
verifyIndexCount(op,
"atomic barrier", atomicBarrierAddressType,
1246 op.getAtomicBarrierIndices().size())))
1252 return op.emitOpError(
"atomic barrier address must be in LDS.");
1255 if (op.getEarlyTimeout() && !op.getWorkgroupMask())
1256 return op.emitOpError(
1257 "early timeout does not apply when workgroup_mask is not set.");
1261template <
typename DescriptorOp,
typename FoldAdaptor>
1282 op.setGlobalStaticSizes(staticGlobalSizes);
1283 op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
1286 staticGlobalStrides);
1287 op.setGlobalStaticStrides(staticGlobalStrides);
1288 op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
1292 op.setSharedStaticSizes(staticSharedSizes);
1293 op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
1294 return op.getResult();
1297LogicalResult MakeDmaDescriptorOp::verify() {
1301OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1309LogicalResult MakeGatherDmaDescriptorOp::verify() {
1311 size_t rank = globalStaticSizes.size();
1314 "tensor and tile must be at most of rank two in gather mode.");
1316 Type elementType = cast<VectorType>(
indices.getType()).getElementType();
1318 return emitOpError(
"indices' element type must match base's element type.");
1323OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
1337 LogicalResult matchAndRewrite(ScaledMFMAOp op,
1338 PatternRewriter &rewriter)
const override {
1339 Location loc = op.getLoc();
1340 auto setOpsel = [&op](
unsigned idx, int64_t val) {
1343 op.setScalesIdxA(val);
1346 op.setScalesIdxB(val);
1370 for (
auto opIdx : std::array<int64_t, 2>({3, 4})) {
1371 auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
1374 "defining op not a vector.insert");
1377 if (isa<VectorType>(insertOp.getValueToStore().getType())) {
1379 op,
"scaled mfma operand already packed");
1383 insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
1386 "defining op not a vector.extract");
1389 Value scaleSrc = extractOp.getOperand(0);
1390 auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.
getType());
1391 if (!scaleSrcType) {
1397 if (!scaleSrcType.hasStaticShape()) {
1399 "dynamic dims not yet supported");
1402 int64_t numElements = scaleSrcType.getNumElements();
1403 if (numElements < 4) {
1405 op,
"do not pack if # of scales less than four");
1409 auto extractedPos = llvm::to_vector_of<int64_t>(
1410 llvm::reverse(extractOp.getStaticPosition()));
1411 ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
1412 int64_t scaleSrcRank = scaleSrcType.getRank();
1413 SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
1414 for (int64_t i = 1; i < scaleSrcRank; ++i) {
1415 extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
1417 int64_t idx =
linearize(extractedPos, extractSizes);
1429 int64_t offset = idx - (idx % 4);
1430 int64_t opsel = idx - offset;
1433 if (numElements - offset < size) {
1434 opsel = size - (numElements - idx);
1435 offset = numElements - 4l;
1437 Type scaleSrcElemType = scaleSrcType.getElementType();
1439 VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
1441 vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
1442 auto extract = vector::ExtractStridedSliceOp::create(
1443 rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
1444 ArrayRef{int64_t(1)});
1446 op->setOperand(opIdx, extract);
1447 setOpsel(opIdx, opsel);
1457 results.
add<PackScales>(context);
1464template <
typename T>
1466 MemRefType memrefType = llvm::cast<MemRefType>(op.getBase().getType());
1472 return op.emitOpError(
"barrier must be in workgroup (LDS) memory");
1477LogicalResult DsBarrierInitOp::verify() {
1481LogicalResult DsBarrierPollStateOp::verify() {
1485LogicalResult DsAsyncBarrierArriveOp::verify() {
1489LogicalResult DsBarrierArriveOp::verify() {
1497LogicalResult GlobalPrefetchOp::verify() {
1498 auto src = cast<MemRefType>(getSrc().
getType());
1503 Attribute memSpace = src.getMemorySpace();
1505 return this->
emitOpError(
"the source must have address space attribute");
1507 return this->
emitOpError(
"the source must reside in global address space");
1509 const LoadTemporalHint temporalHint = getTemporalHint();
1510 const Scope scope = getCacheScope();
1511 const bool isSpeculative = getSpeculative();
1514 if (isSpeculative && scope == Scope::WGP)
1516 "does not support speculative prefetch in WGP scope");
1524 if (llvm::is_contained({LoadTemporalHint::NT, LoadTemporalHint::LU},
1526 return this->
emitOpError(
"does not support NT and LU modes");
1528 if (llvm::is_contained({LoadTemporalHint::NT_RT, LoadTemporalHint::RT_NT,
1529 LoadTemporalHint::NT_HT},
1532 return this->
emitOpError(
"operates only in the speculative mode");
1537#define GET_OP_CLASSES
1538#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyDescriptorOp(DescriptorOp op)
static LogicalResult verifyRawBufferOp(T &op)
static LogicalResult verifyDsBarrierOpCommon(T &op)
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless resetOffset is true, in which case the offset is reset to 0.
static LogicalResult verifyIndexCount(OpTy op, StringRef indexName, MemRefType memrefType, int64_t numIndices)
Verifies that the number of indices matches the rank of the indexed memref, emitting an op error mentioning indexName, the expected rank and the actual index count otherwise.
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op, PatternRewriter &rewriter)
Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
static LogicalResult foldGlobalLoadAsyncToLDSConstantMask(GlobalLoadAsyncToLDSOp op, PatternRewriter &rewriter)
static LogicalResult verifyBase(BaseOp op)
static bool staticallyOutOfBounds(OpType op)
static std::optional< uint32_t > getConstantUint32(Value v)
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This class helps build Operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is an float type (with the specified width).
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isGlobalMemorySpace(Attribute memorySpace, bool allowFlat)
ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, IntegerAttr &n, IntegerAttr &k)
Parser for the custom<MNKDimensionList> custom assembly format used by WMMAOp.
bool isWorkgroupMemorySpace(Attribute memorySpace)
bool isFatRawBufferMemorySpace(Attribute memorySpace)
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
llvm::function_ref< Fn > function_ref
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...