#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset to
/// 0 - but in the fat raw buffer address space.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // If resetting the offset yields the identity strides, use the identity
    // layout instead of an equivalent strided attribute.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
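// Worked example of the offset-reset logic above: a source of type
//   memref<8x8xf32, strided<[8, 1], offset: 16>>
// with resetOffset = true first becomes strided<[8, 1], offset: 0>; because
// [8, 1] is the suffix product of the 8x8 shape, the layout then folds to the
// identity map, so the result is a plain memref<8x8xf32> in the
// fat_raw_buffer address space.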
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes.push_back(*resultType);
  return success();
}
FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(
    OpBuilder &builder, unsigned resultIndex, unsigned dim) {
  assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
  // The result aliases the source buffer, so its sizes are the source's.
  return memref::getMixedSize(builder, getLoc(), getSource(), dim);
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
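// The integer memory spaces accepted above follow the AMDGPU backend's
// address-space numbering: 0 (flat) and 1 (global) count as global memory,
// 3 is local/LDS (workgroup), and 7 is the buffer fat pointer address space.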
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  Attribute memorySpace = bufferType.getMemorySpace();
  if (!hasGlobalMemorySpace(memorySpace))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop the op.
    return false;
  return result >= bufferType.getNumElements();
}
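// Worked example: for a memref<4x8xi32> (32 elements, strides [8, 1],
// offset 0) with constant indices (3, 7), result = 3 * 8 + 7 = 31, which is
// in bounds; indices (4, 0) give 32 >= 32, so the access is statically out of
// bounds and the patterns below may fire. Offsets past uint32_t::max are
// conservatively treated as in bounds rather than folded away.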
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};

template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
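// Illustrative effect of these patterns (assumed IR shapes): a bounds-checked
// raw buffer load of memref<16xi32> at constant index 40 is replaced by
// `arith.constant 0 : i32`, matching the hardware's zero-return for
// out-of-bounds buffer loads; the corresponding store is simply erased.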
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
LogicalResult ScaledExtPackedMatrixOp::verify() {
  int32_t blockSize = getBlockSize();
  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");

  int firstScaleByte = getFirstScaleByte();
  int firstScaleLane = getFirstScaleLane();
  auto sourceType = cast<VectorType>(getSource().getType());
  Type elementType = sourceType.getElementType();
  auto floatType = cast<FloatType>(elementType);
  unsigned bitWidth = floatType.getWidth();

  const bool is_fp8 = bitWidth == 8;
  const bool is_block_16 = blockSize == 16;

  if (!is_fp8) {
    if (is_block_16) {
      if (!llvm::is_contained({0, 1}, firstScaleByte)) {
        return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
                           "or 1 for f4 and f6.");
      }
    } else {
      if (!llvm::is_contained({0, 2}, firstScaleByte)) {
        return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
                           "or 2 for f4 and f6.");
      }
    }
  } else if (is_block_16) {
    bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                    ((firstScaleLane == 16) && (firstScaleByte == 2));
    if (!is_valid) {
      return emitOpError("blockSize of 16 can only have (firstScaleLane, "
                         "firstScaleByte) be (0, 0) or (16, 2) for f8.");
    }
  }

  return success();
}
ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
                                                IntegerAttr &m, IntegerAttr &n,
                                                IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";

  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
  return success();
}
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}
LogicalResult ScaledWMMAOp::verify() {
  auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
  auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
  auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
  auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
  auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
  auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;

  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type aElemType = sourceAType.getElementType();
  Type bElemType = sourceBType.getElementType();

  int64_t m = getM();
  int64_t aLen = sourceAType.getNumElements();
  int64_t bLen = sourceBType.getNumElements();
  int64_t expectedOutLen = (m == 16) ? 8 : 16;

  if (destType.getNumElements() != expectedOutLen)
    return emitOpError("expected output vector of length ")
           << expectedOutLen << " but got " << destType.getNumElements();

  if (m == 16) {
    // 16x16x128 case.
    if (aLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceA must have 64 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceB must have 64 elements but got ")
             << bLen;
  } else {
    // 32x16x128 case.
    if (!isF4(aElemType) && !isF4(bElemType))
      return emitOpError("32x16x128 only supports fp4 element types");
    if (aLen != 128)
      return emitOpError(
                 "for 32x16x128, sourceA must have 128 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 32x16x128, sourceB must have 64 elements but got ")
             << bLen;
    if (getAFirstScaleLane() != 0)
      return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
  }

  auto scaleAType = cast<VectorType>(getScaleA().getType());
  auto scaleBType = cast<VectorType>(getScaleB().getType());
  Type scaleAElemType = scaleAType.getElementType();
  Type scaleBElemType = scaleBType.getElementType();

  if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
    return emitOpError(
        "scale operands must have f8 element types (E8M0FNU or E4M3FN)");

  // Accepted matrix/scale element type pairings.
  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
    return success();
  if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
      isF4(bElemType) && isE4M3(scaleBElemType))
    return success();
  if (isF4(aElemType) && isE4M3(scaleAElemType) &&
      (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
    return success();
  if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
      isE4M3(scaleBElemType))
    return success();

  return emitOpError("invalid combination of matrix and scale types: ")
         << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
         << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");
  return success();
}
LogicalResult SparseMFMAOp::verify() {
  constexpr uint32_t waveSize = 64;

  auto sparseType = cast<VectorType>(getSourceA().getType());
  auto denseType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sparseElem = sparseType.getElementType();
  Type denseElem = denseType.getElementType();
  int64_t sparseLen = sparseType.getNumElements();
  int64_t denseLen = denseType.getNumElements();
  int64_t destLen = destType.getNumElements();

  if (denseLen != 2 * sparseLen)
    return emitOpError("expected dense source operand to have exactly double "
                       "the number of elements of the sparse source operand");

  bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
  if (!bothFloat8 && sparseElem != denseElem)
    return emitOpError("expected source operands to have the same element type");

  bool is8BitSource = sparseElem.getIntOrFloatBitWidth() == 8;
  if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
    return emitOpError("ABID must be 0 or 1 for 8-bit source data");
  if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
    return emitOpError("ABID must be between 0 and 3 for 16-bit source data");

  auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
  if (is8BitSource) {
    if (sparseIdxType.getNumElements() != 2 ||
        !sparseIdxType.getElementType().isInteger(16))
      return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
                         "source data, but got ")
             << getSparseIdx().getType();
  } else {
    if (sparseIdxType.getNumElements() != 4 ||
        !sparseIdxType.getElementType().isInteger(8))
      return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
                         "source data, but got ")
             << getSparseIdx().getType();
  }

  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
  if (denseLen != expectedSourceElems)
    return emitOpError("expected " + Twine(expectedSourceElems) +
                       " source values for this operation but got " +
                       Twine(denseLen));

  int64_t expectedDestElems = (getM() * getN()) / waveSize;
  if (destLen != expectedDestElems)
    return emitOpError("expected " + Twine(expectedDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));
  return success();
}
LogicalResult DPPOp::verify() {
  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");
  return success();
}
/// Remove amdgpu.lds_barrier after amdgpu.lds_barrier.
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
                                                 PatternRewriter &rewriter) {
  if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add(eraseRedundantLDSBarrierOps);
}
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
                                PatternRewriter &rewriter) const override {
    auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
    if (!next)
      return failure();

    auto setters = {&MemoryCounterWaitOp::setLoad,
                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
                    &MemoryCounterWaitOp::setExp,
                    &MemoryCounterWaitOp::setTensor};
    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
                    op.getTensor()};
    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
                    next.getExp(), next.getTensor()};
    rewriter.modifyOpInPlace(op, [&] {
      // Fuse by taking the per-counter minimum of the two waits.
      for (auto [setter, lhs, rhs] :
           llvm::zip_equal(setters, lhsVals, rhsVals)) {
        if (lhs && rhs)
          (op.*setter)(std::min(*lhs, *rhs));
        else if (rhs)
          (op.*setter)(*rhs);
      }
    });
    rewriter.eraseOp(next);
    return success();
  }
};

void MemoryCounterWaitOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FuseMemoryCounterWaitOp>(context);
}
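// Merging example: `memory_counter_wait load(2)` immediately followed by
// `memory_counter_wait load(1)` fuses into a single wait with load(1), since
// waiting until a counter is <= 2 and then <= 1 is equivalent to a single
// wait until it is <= 1.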
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");
  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
/// If the source or destination of a GatherToLDSOp is produced by a
/// memref.cast that can fold into its consumer, fold the cast away.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Maps element type size (bits) to the required number of elements.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end())
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  if (numElements != validNumElems->second)
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;

  return success();
}
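// Reading the table: a 16-bit element type must transfer exactly 4 elements
// (vector<4xf16>, 64 bits total per lane), while 4-bit and 6-bit element
// types transfer 16 elements (64 and 96 bits respectively).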
template <typename BaseOp>
static LogicalResult verifyBase(BaseOp op) {
  auto ldsType = cast<MemRefType>(op.getLds().getType());
  auto globalType = cast<MemRefType>(op.getGlobal().getType());
  if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
    return op.emitOpError(
        "lds memref must have workgroup address space attribute.");
  if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
    return op.emitOpError(
        "global memref must have global address space attribute.");

  Type elementType = ldsType.getElementType();
  unsigned width = elementType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return op.emitOpError(
               "element type must be 1, 2, 4, or 8 bytes long but type was ")
           << width << " bits long.";
  return success();
}

LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
// Additional checks for the gather variant's index vector. The name and
// signature below are placeholders inferred from context.
static LogicalResult
verifyGatherIndexing(function_ref<InFlightDiagnostic()> emitError,
                     MLIRContext *ctx, Type elementType, Type indexType) {
  unsigned width = elementType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return emitError()
           << "element type must be 1, 2, 4, or 8 bytes wide but type "
           << elementType << " is " << width / 8 << " bytes wide.";
  Type i16 = IntegerType::get(ctx, 16);
  Type i32 = IntegerType::get(ctx, 32);
  if (!llvm::is_contained({i16, i32}, indexType))
    return emitError() << "index type must be i16 or i32 but index type is "
                       << indexType;
  return success();
}

LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
template <typename DescriptorOp>
static LogicalResult verifyDescriptorOp(DescriptorOp op) {
  ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
  if (globalStaticStrides.empty())
    return op.emitOpError("strides must not be empty.");
  if (globalStaticStrides.back() != 1)
    return op.emitOpError("strides for the innermost dimension must be 1.");

  ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 5)
    return op.emitOpError("tensor and tile must be at most of rank 5.");
  if (rank != globalStaticStrides.size())
    return op.emitOpError("strides and sizes must have same rank.");

  ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
  if (rank != sharedStaticSizes.size())
    return op.emitOpError("tensor must have same rank as tile.");

  unsigned elementTypeWidth = op.getElementTypeWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
    return op.emitOpError(
               "element type width must be 1, 2, 4 or 8 bytes, but was ")
           << elementTypeWidth << " bits long";

  if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
    auto atomicBarrierAddressType =
        cast<MemRefType>(atomicBarrierAddress.getType());
    if (!hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()))
      return op.emitOpError("atomic barrier address must be in LDS.");
  }

  if (op.getEarlyTimeout() && !op.getWorkgroupMask())
    return op.emitOpError(
        "early timeout does not apply when workgroup_mask is not set.");
  return success();
}
template <typename DescriptorOp, typename FoldAdaptor>
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
  // Fold dynamic sizes and strides that turn out to be constants into the
  // static attribute arrays. (The mixed-list accessor names below are
  // reconstructed and may differ from the generated ones.)
  SmallVector<OpFoldResult> mixedGlobalSizes(op.getMixedGlobalSizes());
  SmallVector<OpFoldResult> mixedGlobalStrides(op.getMixedGlobalStrides());
  SmallVector<OpFoldResult> mixedSharedSizes(op.getMixedSharedSizes());
  bool changedSizes = succeeded(foldDynamicIndexList(mixedGlobalSizes));
  bool changedStrides = succeeded(foldDynamicIndexList(mixedGlobalStrides));
  bool changedShared = succeeded(foldDynamicIndexList(mixedSharedSizes));
  if (!changedSizes && !changedStrides && !changedShared)
    return nullptr;

  SmallVector<Value> dynamicGlobalSizes;
  SmallVector<int64_t> staticGlobalSizes;
  dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
                             staticGlobalSizes);
  op.setGlobalStaticSizes(staticGlobalSizes);
  op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);

  SmallVector<Value> dynamicGlobalStrides;
  SmallVector<int64_t> staticGlobalStrides;
  dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
                             staticGlobalStrides);
  op.setGlobalStaticStrides(staticGlobalStrides);
  op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);

  SmallVector<Value> dynamicSharedSizes;
  SmallVector<int64_t> staticSharedSizes;
  dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
                             staticSharedSizes);
  op.setSharedStaticSizes(staticSharedSizes);
  op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
  return op.getResult();
}
LogicalResult MakeDmaDescriptorOp::verify() {
  return verifyDescriptorOp(*this);
}

OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}

LogicalResult MakeGatherDmaDescriptorOp::verify() {
  ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 2)
    return emitOpError(
        "tensor and tile must be at most of rank two in gather mode.");

  Value indices = getIndices();
  Type elementType = cast<VectorType>(indices.getType()).getElementType();
  if (elementType != getBase().getType().getElementType())
    return emitOpError("indices' element type must match base's element type.");

  return verifyDescriptorOp(*this);
}

OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}
/// Pack small-float scales for scaled MFMA: rather than inserting a single
/// extracted scale, use a 4-byte packed slice of the scale vector and record
/// which byte to use in the scalesIdx (opsel) field.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Operands 3 and 4 are the A and B scales.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      if (isa<VectorType>(insertOp.getValueToStore().getType()))
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");
      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType)
        return rewriter.notifyMatchFailure(op, "not a vector type");
      if (!scaleSrcType.hasStaticShape())
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements < 4)
        return rewriter.notifyMatchFailure(
            op, "do not pack if # of scales less than four");

      // Linearize the (possibly multi-dimensional) extract position.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i)
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      int64_t idx = linearize(extractedPos, extractSizes);

      // Take a packed slice of 4 scales starting at the aligned offset; opsel
      // selects the wanted scale within the slice. Near the end of the
      // vector, shift the slice back so it stays in bounds.
      constexpr int64_t size = 4;
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
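// Offset/opsel arithmetic example: extracting scale 6 from an 8-element scale
// vector gives offset = 4 and opsel = 2, so the op reads the packed slice
// [4, 8) and selects its third element. Near the tail, scale 5 of a 6-element
// vector yields offset = 2 and opsel = 3, keeping the 4-wide slice in bounds
// while still addressing element 5.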
template <typename T>
static LogicalResult verifyDsBarrierOpCommon(T &op) {
  MemRefType memrefType = llvm::cast<MemRefType>(op.getBase().getType());
  if (!hasWorkgroupMemorySpace(memrefType.getMemorySpace()))
    return op.emitOpError("barrier must be in workgroup (LDS) memory");
  return success();
}

LogicalResult DsBarrierInitOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}

LogicalResult DsBarrierPollStateOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}

LogicalResult DsAsyncBarrierArriveOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}

LogicalResult DsBarrierArriveOp::verify() {
  return verifyDsBarrierOpCommon(*this);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"