#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset
/// to 0.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // If resetting the offset yields identity strides, use the identity map.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}
FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(
    OpBuilder &builder, unsigned resultIndex, int64_t dim) {
  assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
  // The cast preserves sizes, so result dimensions come from the source.
  return memref::getMixedSize(builder, getLoc(), getSource(), dim);
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());

  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}

static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
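// Note (explanatory, not from the original source): raw buffer accesses with
// bounds checking enabled are clamped by the hardware, so an access whose
// linear index is provably past the end of the memref never touches memory.
// The helper below detects that case when the offset, indices, and strides
// are all statically known.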
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop
    return false;
  return result >= bufferType.getNumElements();
}
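// Note (explanatory, not from the original source): since an out-of-bounds
// bounds-checked read returns 0, statically out-of-bounds loads (including
// cmpswap, which returns the loaded value) fold to a zero constant, while
// statically out-of-bounds stores and value-less atomics are simply erased.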
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
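// Note (explanatory, not from the original source): the checks below encode
// where per-block scales may live: for f4/f6 data, a block size of 16 allows
// firstScaleByte 0 or 1 and a block size of 32 allows 0 or 2; for f8 data
// with block size 16, only (firstScaleLane, firstScaleByte) of (0, 0) or
// (16, 2) is accepted.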
LogicalResult ScaledExtPackedMatrixOp::verify() {
  int blockSize = getBlockSize();
  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");

  int firstScaleByte = getFirstScaleByte();
  int firstScaleLane = getFirstScaleLane();
  auto sourceType = cast<VectorType>(getSource().getType());
  Type elementType = sourceType.getElementType();
  auto floatType = cast<FloatType>(elementType);
  unsigned bitWidth = floatType.getWidth();

  const bool is_fp8 = bitWidth == 8;
  const bool is_block_16 = blockSize == 16;

  if (!is_fp8) {
    if (is_block_16) {
      if (!llvm::is_contained({0, 1}, firstScaleByte)) {
        return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
                           "or 1 for f4 and f6.");
      }
    } else if (!llvm::is_contained({0, 2}, firstScaleByte)) {
      return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
                         "or 2 for f4 and f6.");
    }
  } else if (is_block_16) {
    bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                    ((firstScaleLane == 16) && (firstScaleByte == 2));
    if (!is_valid) {
      return emitOpError("blockSize of 16 can only have (firstScaleLane, "
                         "firstScaleByte) be (0, 0) or (16, 2) for f8.");
    }
  }
  return success();
}
ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
                                  IntegerAttr &n, IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";

  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
  return success();
}
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}
LogicalResult ScaledWMMAOp::verify() {
  auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
  auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
  auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
  auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
  auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
  auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;

  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type aElemType = sourceAType.getElementType();
  Type bElemType = sourceBType.getElementType();

  int64_t m = getM();
  int64_t aLen = sourceAType.getNumElements();
  int64_t bLen = sourceBType.getNumElements();
  int64_t expectedOutLen = (m == 16) ? 8 : 16;

  if (destType.getNumElements() != expectedOutLen)
    return emitOpError("expected output vector of length ")
           << expectedOutLen << " but got " << destType.getNumElements();

  if (m == 16) {
    // 16x16x128 tile shape.
    if (aLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceA must have 64 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceB must have 64 elements but got ")
             << bLen;
  } else {
    // 32x16x128 tile shape.
    if (!isF4(aElemType) && !isF4(bElemType))
      return emitOpError("32x16x128 only supports fp4 element types");
    if (aLen != 128)
      return emitOpError(
                 "for 32x16x128, sourceA must have 128 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 32x16x128, sourceB must have 64 elements but got ")
             << bLen;
    if (getAFirstScaleLane() != 0)
      return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
  }

  auto scaleAType = cast<VectorType>(getScaleA().getType());
  auto scaleBType = cast<VectorType>(getScaleB().getType());
  Type scaleAElemType = scaleAType.getElementType();
  Type scaleBElemType = scaleBType.getElementType();

  if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
    return emitOpError(
        "scale operands must have f8 element types (E8M0FNU or E4M3FN)");

  // Accept only the supported combinations of matrix and scale element types.
  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
    return success();
  if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
      isF4(bElemType) && isE4M3(scaleBElemType))
    return success();
  if (isF4(aElemType) && isE4M3(scaleAElemType) &&
      (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
    return success();
  if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
      isE4M3(scaleBElemType))
    return success();

  return emitOpError("invalid combination of matrix and scale types: ")
         << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
         << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the packed integer source types to counts of i8 elements.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  } else if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");
  return success();
}
LogicalResult SparseMFMAOp::verify() {
  constexpr uint32_t waveSize = 64;

  auto sparseType = cast<VectorType>(getSourceA().getType());
  auto denseType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sparseElem = sparseType.getElementType();
  Type denseElem = denseType.getElementType();
  int64_t sparseLen = sparseType.getNumElements();
  int64_t denseLen = denseType.getNumElements();
  int64_t destLen = destType.getNumElements();

  if (denseLen != 2 * sparseLen)
    return emitOpError("expected dense source operand to have exactly double "
                       "the number of elements of the sparse source operand");

  bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
  if (!bothFloat8 && sparseElem != denseElem)
    return emitOpError(
        "expected source operands to have the same element type");

  bool is8BitSource = sparseElem.getIntOrFloatBitWidth() == 8;
  if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
    return emitOpError("ABID must be 0 or 1 for 8-bit source data");
  if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
    return emitOpError("ABID must be between 0 and 3 for 16-bit source data");

  auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
  if (is8BitSource) {
    if (sparseIdxType.getNumElements() != 2 ||
        !sparseIdxType.getElementType().isInteger(16))
      return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
                         "source data, but got ")
             << getSparseIdx().getType();
  } else {
    if (sparseIdxType.getNumElements() != 4 ||
        !sparseIdxType.getElementType().isInteger(8))
      return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
                         "source data, but got ")
             << getSparseIdx().getType();
  }

  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
  if (denseLen != expectedSourceElems)
    return emitOpError("expected " + Twine(expectedSourceElems) +
                       " source values for this operation but got " +
                       Twine(denseLen));

  int64_t expectedDestElems = (getM() * getN()) / waveSize;
  if (destLen != expectedDestElems)
    return emitOpError("expected " + Twine(expectedDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));
  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }
  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }
  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");
  return success();
}
/// Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
                                                 PatternRewriter &rewriter) {
  if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
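// Note (explanatory, not from the original source): two back-to-back
// amdgpu.memory_counter_wait ops are equivalent to a single wait on the
// per-counter minimum of their thresholds, which is what the pattern below
// folds them into.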
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
                                PatternRewriter &rewriter) const override {
    auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
    if (!next)
      return failure();

    auto setters = {&MemoryCounterWaitOp::setLoad,
                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
                    &MemoryCounterWaitOp::setExp,
                    &MemoryCounterWaitOp::setTensor};
    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
                    op.getTensor()};
    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
                    next.getExp(), next.getTensor()};

    for (auto [setter, lhs, rhs] :
         llvm::zip_equal(setters, lhsVals, rhsVals)) {
      if (lhs && rhs)
        (op.*setter)(std::min(*lhs, *rhs));
      else if (rhs)
        (op.*setter)(*rhs);
    }
    rewriter.eraseOp(next);
    return success();
  }
};
void MemoryCounterWaitOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FuseMemoryCounterWaitOp>(context);
}
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transfering type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Maps element type size (bits) to the expected number of loaded elements.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},  // 16 x 4 bits = 64 bits
      {6, 16},  // 16 x 6 bits = 96 bits
      {8, 8},   // 8 x 8 bits = 64 bits
      {16, 4},  // 4 x 16 bits = 64 bits
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end())
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";

  if (numElements != validNumElems->second)
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;

  return success();
}
template <typename BaseOp>
static LogicalResult verifyBase(BaseOp op) {
  auto ldsType = cast<MemRefType>(op.getLds().getType());
  auto globalType = cast<MemRefType>(op.getGlobal().getType());
  if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
    return op.emitOpError(
        "lds memref must have workgroup address space attribute.");
  if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
    return op.emitOpError(
        "global memref must have global address space attribute.");

  Type elementType = ldsType.getElementType();
  unsigned width = elementType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return op.emitOpError(
               "element type must be 1, 2, 4, or 8 bytes long but type was ")
           << width << " bits long.";
  return success();
}

LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
  unsigned width = elementType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return emitError()
           << "element type must be 1, 2, 4, or 8 bytes wide but type "
           << elementType << " is " << width / 8 << " bytes wide.";

  Type i16 = IntegerType::get(ctx, 16);
  Type i32 = IntegerType::get(ctx, 32);
  if (!llvm::is_contained({i16, i32}, indexType))
    return emitError() << "index type must be i16 or i32 but index type is "
                       << indexType;
  return success();
}

LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
template <typename DescriptorOp>
static LogicalResult verifyDescriptorOp(DescriptorOp op) {
  ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
  if (globalStaticStrides.empty())
    return op.emitOpError("strides must not be empty.");
  if (globalStaticStrides.back() != 1)
    return op.emitOpError("strides for the innermost dimension must be 1.");

  ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 5)
    return op.emitOpError("tensor and tile must be at most of rank 5.");
  if (rank != globalStaticStrides.size())
    return op.emitOpError("strides and sizes must have same rank.");

  ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
  if (rank != sharedStaticSizes.size())
    return op.emitOpError("tensor must have same rank as tile.");

  unsigned elementTypeWidth = op.getElementTypeWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
    return op.emitOpError(
               "element type width must be 1, 2, 4 or 8 bytes, but was ")
           << elementTypeWidth << " bits long";

  if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
    auto atomicBarrierAddressType =
        cast<MemRefType>(atomicBarrierAddress.getType());
    if (!hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()))
      return op.emitOpError("atomic barrier address must be in LDS.");
  }

  if (op.getEarlyTimeout() && !op.getWorkgroupMask())
    return op.emitOpError(
        "early timeout does not apply when workgroup_mask is not set.");
  return success();
}
template <typename DescriptorOp, typename FoldAdaptor>
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
  SmallVector<OpFoldResult> mixedGlobalSizes = getMixedValues(
      op.getGlobalStaticSizes(), op.getGlobalDynamicSizes(), op.getContext());
  SmallVector<OpFoldResult> mixedGlobalStrides =
      getMixedValues(op.getGlobalStaticStrides(), op.getGlobalDynamicStrides(),
                     op.getContext());
  SmallVector<OpFoldResult> mixedSharedSizes = getMixedValues(
      op.getSharedStaticSizes(), op.getSharedDynamicSizes(), op.getContext());
  if (failed(foldDynamicIndexList(mixedGlobalSizes)) &&
      failed(foldDynamicIndexList(mixedGlobalStrides)) &&
      failed(foldDynamicIndexList(mixedSharedSizes)))
    return {};

  SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
      dynamicSharedSizes;
  SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
      staticSharedSizes;
  dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
                             staticGlobalSizes);
  op.setGlobalStaticSizes(staticGlobalSizes);
  op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
  dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
                             staticGlobalStrides);
  op.setGlobalStaticStrides(staticGlobalStrides);
  op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
  dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
                             staticSharedSizes);
  op.setSharedStaticSizes(staticSharedSizes);
  op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
  return op.getResult();
}

LogicalResult MakeDmaDescriptorOp::verify() {
  return verifyDescriptorOp(*this);
}

OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}
LogicalResult MakeGatherDmaDescriptorOp::verify() {
  ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 2)
    return emitOpError(
        "tensor and tile must be at most of rank two in gather mode.");

  Value indices = getIndices();
  Type elementType = cast<VectorType>(indices.getType()).getElementType();
  if (elementType != getBase().getType().getElementType())
    return emitOpError("indices' element type must match base's element type.");
  return verifyDescriptorOp(*this);
}

OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}
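// Note (explanatory, not from the original source): scaled MFMA scale
// operands are packed four per 32-bit register, selected by an opsel index.
// Instead of inserting one scalar scale into a fresh vector, the pattern
// below extracts an aligned four-element slice of the original scale vector
// and records which element to use via setScalesIdxA/B, saving the insert.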
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // The scale operands are at positions 3 (scales A) and 4 (scales B).
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      }
      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
        return rewriter.notifyMatchFailure(op,
                                           "scaled mfma operand already packed");
      }
      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");
      }

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType) {
        return rewriter.notifyMatchFailure(op, "not a vector type");
      }
      if (!scaleSrcType.hasStaticShape()) {
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");
      }

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4) {
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");
      }

      // Compute the linearized index of the extracted scale.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // Take an aligned group of four scales containing idx; opsel selects
      // the scale within that group.
      int64_t size = 4l;
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"