#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}
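// For context: clients typically make the dialect available through a
// DialectRegistry before parsing IR that uses it. A minimal sketch (the
// surrounding tool setup is assumed, not part of this file):
//   mlir::DialectRegistry registry;
//   registry.insert<mlir::amdgpu::AMDGPUDialect>();
//   mlir::MLIRContext context(registry);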
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    if (source.hasStaticShape()) {
    } else if (source.getRank() <= 1) {
      if (stridesIfIdentity == stridedLayout.getStrides()) {
        newLayout = AffineMapAttr::get(
  return (MemRefType)(mb);
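// Worked example of the layout rewrite above: casting
// memref<8x8xf32, strided<[8, 1], offset: ?>> with resetOffset = true yields
// a result type with the same strides but the offset pinned to zero,
// strided<[8, 1], offset: 0>, in the fat raw buffer address space.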
LogicalResult FatRawBufferCastOp::inferReturnTypes(
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  FailureOr<MemRefType> resultType =
FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
  assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
  Value source = getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  if (sourceType.isDynamicDim(dim))
        builder.createOrFold<memref::DimOp>(getLoc(), source, dim));
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
  if (failed(expectedResultType))
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
           << *expectedResultType << " but got " << getResult().getType();
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;

  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;

  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
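// Example: memref<16xf32, #gpu.address_space<workgroup>> and the raw integer
// form memref<16xf32, 3> both satisfy the workgroup predicate; likewise
// address space 7 and #amdgpu.address_space<fat_raw_buffer> both identify
// fat raw buffer memory.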
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
LogicalResult RawBufferAtomicFaddOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFmaxOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicSmaxOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicUminOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicCmpswapOp::verify() { return verifyRawBufferOp(*this); }

  return cst.getZExtValue();
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
  if (op.getSgprOffset()) {
  if (strides.size() != op.getIndices().size())
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    indexVal += stride * *idxVal;
  if (result > std::numeric_limits<uint32_t>::max())
  return result >= bufferType.getNumElements();
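// Worked example: for memref<4x4xi32> (strides [4, 1], 16 elements) with
// constant indices (3, 5), the linearized offset is 3 * 4 + 5 = 17, and
// 17 >= 16, so the access is statically out of bounds.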
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    Type loadType = op.getResult().getType();

template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
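// The pattern bodies are elided above. Their intent follows from the names:
// a load proven statically out of bounds folds to a constant of the load's
// result type (out-of-bounds raw-buffer reads return 0 in hardware), and a
// store proven out of bounds is erased as a no-op.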
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);

  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
LogicalResult ScaledExtPackedMatrixOp::verify() {
  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
  int firstScaleByte = getFirstScaleByte();
  int firstScaleLane = getFirstScaleLane();
  auto sourceType = cast<VectorType>(getSource().getType());
  Type elementType = sourceType.getElementType();
  auto floatType = cast<FloatType>(elementType);
  unsigned bitWidth = floatType.getWidth();

  const bool is_fp8 = bitWidth == 8;
  const bool is_block_16 = blockSize == 16;

    if (!llvm::is_contained({0, 1}, firstScaleByte)) {
      return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
                         "or 1 for f4 and f6.");
    if (!llvm::is_contained({0, 2}, firstScaleByte)) {
      return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
                         "or 2 for f4 and f6.");

    bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                    ((firstScaleLane == 16) && (firstScaleByte == 2));
      return emitOpError("blockSize of 16 can only have (firstScaleLane, "
                         "firstScaleByte) be (0, 0) or (16, 2) for f8.");
    IntegerAttr &m, IntegerAttr &n,
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
        "source element types must match (except for fp8/bf8) but have ")
        << sourceAType << " and " << sourceBType;

    return emitOpError("clamp flag is not supported for float types");
  if (getUnsignedA() || getUnsignedB())
    return emitOpError("unsigned flags are not supported for float types");
LogicalResult ScaledWMMAOp::verify() {
  auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
  auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
  auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
  auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
  auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
  auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;

  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type aElemType = sourceAType.getElementType();
  Type bElemType = sourceBType.getElementType();

  int64_t aLen = sourceAType.getNumElements();
  int64_t bLen = sourceBType.getNumElements();
  int64_t expectedOutLen = (m == 16) ? 8 : 16;

  if (destType.getNumElements() != expectedOutLen)
    return emitOpError("expected output vector of length ")
           << expectedOutLen << " but got " << destType.getNumElements();

        "for 16x16x128, sourceA must have 64 elements but got ")
        "for 16x16x128, sourceB must have 64 elements but got ")

    if (!isF4(aElemType) && !isF4(bElemType))
      return emitOpError("32x16x128 only supports fp4 element types");
        "for 32x16x128, sourceA must have 128 elements but got ")
        "for 32x16x128, sourceB must have 64 elements but got ")

    if (getAFirstScaleLane() != 0)
      return emitOpError("for 32x16x128, a_first_scale_lane must be 0");

  auto scaleAType = cast<VectorType>(getScaleA().getType());
  auto scaleBType = cast<VectorType>(getScaleB().getType());
  Type scaleAElemType = scaleAType.getElementType();
  Type scaleBElemType = scaleBType.getElementType();

  if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
        "scale operands must have f8 element types (E8M0FNU or E4M3FN)");

  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))

  if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
      isF4(bElemType) && isE4M3(scaleBElemType))

  if (isF4(aElemType) && isE4M3(scaleAElemType) &&
      (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))

  if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
      isE4M3(scaleBElemType))

  return emitOpError("invalid combination of matrix and scale types: ")
         << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
         << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();

  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();

  Type sourceBType = getSourceB().getType();
  Type sourceBElem = sourceBType;
  if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
    sourceBLen = sourceBVector.getNumElements();
    sourceBElem = sourceBVector.getElementType();

    return emitOpError("expected both source operands to have small-float "
                       "elements if one does");
  if (sourceLen != sourceBLen)
        "expected both small-float source vectors to have the same length");

  if (sourceType != sourceBType)
    return emitOpError("expected both non-small-float source operand types "

    sourceElem = b.getI8Type();
    sourceElem = b.getI8Type();

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +

  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
        "negation flags only available for double-precision operations");
LogicalResult SparseMFMAOp::verify() {
  constexpr uint32_t waveSize = 64;

  auto sparseType = cast<VectorType>(getSourceA().getType());
  auto denseType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sparseElem = sparseType.getElementType();
  Type denseElem = denseType.getElementType();
  int64_t sparseLen = sparseType.getNumElements();
  int64_t denseLen = denseType.getNumElements();
  int64_t destLen = destType.getNumElements();

  if (denseLen != 2 * sparseLen)
    return emitOpError("expected dense source operand to have exactly double "
                       "the number of elements of the sparse source operand");

  if (!bothFloat8 && sparseElem != denseElem)
        "expected source operands to have the same element type");

  if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
    return emitOpError("ABID must be 0 or 1 for 8-bit source data");
  if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
    return emitOpError("ABID must be between 0 and 3 for 16-bit source data");

  auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
    if (sparseIdxType.getNumElements() != 2 ||
        !sparseIdxType.getElementType().isInteger(16))
      return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
                         "source data, but got ")
             << getSparseIdx().getType();
    if (sparseIdxType.getNumElements() != 4 ||
        !sparseIdxType.getElementType().isInteger(8))
      return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
                         "source data, but got ")
             << getSparseIdx().getType();

  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
  if (denseLen != expectedSourceElems)
    return emitOpError("expected " + Twine(expectedSourceElems) +
                       " source values for this operation but got " +

  if (destLen != expectedDestElems)
    return emitOpError("expected " + Twine(expectedDestElems) +
                       " result values for this operation but got " +
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");

  DPPPerm kind = getKind();

  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
            "Each element of quad_perm must be in the range [0, 3]");

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {

  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
                                PatternRewriter &rewriter) const override {
    auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());

    auto setters = {&MemoryCounterWaitOp::setLoad,
                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
                    &MemoryCounterWaitOp::setExp,
                    &MemoryCounterWaitOp::setTensor};
    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
                    next.getExp(), next.getTensor()};

    for (auto [setter, lhs, rhs] :
         llvm::zip_equal(setters, lhsVals, rhsVals)) {
      (op.*setter)(std::min(*lhs, *rhs));

void MemoryCounterWaitOp::getCanonicalizationPatterns(
  results.add<FuseMemoryCounterWaitOp>(context);
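// Example of the fusion above (schematic, not exact assembly): two adjacent
// waits with load counts 2 and 1 collapse into a single wait with load count
// min(2, 1) = 1, which satisfies both ordering requirements.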
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type innermost dim must be contiguous");

  auto elemType = srcType.getElementType();
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");
        "source memory address space must be global or fat raw buffer");
    return emitOpError("destination memory address space must be Workgroup");
  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(gatherOp,
                                   [&] { operand.assign(castOp.getSource()); });

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

  results.add<FoldGatherToLDSOfCast>(context);
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end())
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";

  if (numElements != validNumElems->second)
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
template <typename BaseOp>
static LogicalResult verifyBase(BaseOp op) {
  auto ldsType = cast<MemRefType>(op.getLds().getType());
  auto globalType = cast<MemRefType>(op.getGlobal().getType());
    return op.emitOpError(
        "lds memref must have workgroup address space attribute.");
    return op.emitOpError(
        "global memref must have global address space attribute.");

  Type elementType = ldsType.getElementType();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return op.emitOpError(
               "element type must be 1, 2, 4, or 8 bytes long but type was ")
           << width << " bits long.";

LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }

  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return emitError()
           << "element type must be 1, 2, 4, or 8 bytes wide but type "
           << elementType << " is " << width / 8 << " bytes wide.";

  Type i16 = IntegerType::get(ctx, 16);
  Type i32 = IntegerType::get(ctx, 32);
  if (!llvm::is_contained({i16, i32}, indexType))
    return emitError() << "index type must be i16 or i32 but index type is "

LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
template <typename DescriptorOp>
static LogicalResult verifyDescriptorOp(DescriptorOp op) {
  if (globalStaticStrides.empty())
    return op.emitOpError("strides must not be empty.");
  if (globalStaticStrides.back() != 1)
    return op.emitOpError("strides for the innermost dimension must be 1.");

  size_t rank = globalStaticSizes.size();
    return op.emitOpError("tensor and tile must be at most of rank 5.");
  if (rank != globalStaticStrides.size())
    return op.emitOpError("strides and sizes must have same rank.");

  if (rank != sharedStaticSizes.size())
    return op.emitOpError("tensor must have same rank as tile.");

  unsigned elementTypeWidth = op.getElementTypeWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
    return op.emitOpError(
               "element type width must be 1, 2, 4 or 8 bytes, but was ")
           << elementTypeWidth << " bits long";

  if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
    auto atomicBarrierAddressType =
        cast<MemRefType>(atomicBarrierAddress.getType());
      return op.emitOpError("atomic barrier address must be in LDS.");

  if (op.getEarlyTimeout() && !op.getWorkgroupMask())
    return op.emitOpError(
        "early timeout does not apply when workgroup_mask is not set.");
template <typename DescriptorOp, typename FoldAdaptor>
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
  op.setGlobalStaticSizes(staticGlobalSizes);
  op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);

                              staticGlobalStrides);
  op.setGlobalStaticStrides(staticGlobalStrides);
  op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);

  op.setSharedStaticSizes(staticSharedSizes);
  op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
  return op.getResult();
LogicalResult MakeDmaDescriptorOp::verify() {

OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {

LogicalResult MakeGatherDmaDescriptorOp::verify() {
  size_t rank = globalStaticSizes.size();
        "tensor and tile must be at most of rank two in gather mode.");

  Type elementType = cast<VectorType>(indices.getType()).getElementType();
    return emitOpError("indices' element type must match base's element type.");

OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
        op.setScalesIdxA(val);
        op.setScalesIdxB(val);

    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
                                           "defining op not a vector.insert");

      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");

          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType) {

      if (!scaleSrcType.hasStaticShape()) {
                                           "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4) {
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");

      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});

      op->setOperand(opIdx, extract);
      setOpsel(opIdx, opsel);
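// Worked example of the offset/opsel arithmetic above: extracting scale
// index 6 from a 16-element source gives offset = 6 - (6 % 4) = 4 and
// opsel = 6 - 4 = 2, so a 4-element slice starting at element 4 is packed
// into the operand and the op selects byte 2 of it.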
  results.add<PackScales>(context);

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"