#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
/// All operations in the AMDGPU dialect are legal to inline.
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}
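// Usage sketch: clients opt in by loading the dialect into their context,
// e.g. `context.loadDialect<amdgpu::AMDGPUDialect>();` (a standard
// MLIRContext API), after which the generated ops, types, attributes, and
// the inliner interface registered above become available.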
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // If zeroing the offset turns the strided layout into the identity
    // layout, prefer the identity form.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}
FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(
    OpBuilder &builder, int resultIndex, int dim) {
  assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
  // The cast preserves the shape of the source memref.
  return memref::getMixedSize(builder, getLoc(), getSource(), dim);
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
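// The integer cases mirror the AMDGPU backend's numeric address spaces:
// 0 (flat/generic, treated as global here) and 1 (global), 3 (local/LDS),
// and 7 (the buffer fat pointer).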
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  if (!hasGlobalMemorySpace(bufferType.getMemorySpace()))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflowed the buffer's 32-bit index space; draw no conclusion.
    return false;
  return result >= bufferType.getNumElements();
}
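// Worked example: for a memref<16xi32> buffer (16 elements, identity layout),
// a bounds-checked access at constant index 20 computes result = 20 >= 16 and
// is therefore statically out of bounds, so the patterns below may fold it.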
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    // Out-of-bounds bounds-checked reads return 0, so materialize that
    // constant directly.
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    // Out-of-bounds bounds-checked writes are dropped by the hardware.
    rw.eraseOp(op);
    return success();
  }
};
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
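// Net effect of these canonicalizations: a raw-buffer access whose address is
// statically provably past the end of the buffer folds away entirely. Loads
// (including cmpswap, which produces a value) become a zero constant of the
// loaded type and stores/atomics are erased, matching the hardware's
// bounds-checked semantics of reads returning 0 and writes being ignored.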
LogicalResult ScaledExtPackedMatrixOp::verify() {
  int blockSize = getBlockSize();
  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
  int firstScaleByte = getFirstScaleByte();
  int firstScaleLane = getFirstScaleLane();

  auto sourceType = cast<VectorType>(getSource().getType());
  Type elementType = sourceType.getElementType();
  auto floatType = cast<FloatType>(elementType);
  unsigned bitWidth = floatType.getWidth();

  const bool is_fp8 = bitWidth == 8;
  const bool is_block_16 = blockSize == 16;

  if (!is_fp8) {
    if (is_block_16) {
      if (!llvm::is_contained({0, 1}, firstScaleByte)) {
        return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
                           "or 1 for f4 and f6.");
      }
    } else {
      if (!llvm::is_contained({0, 2}, firstScaleByte)) {
        return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
                           "or 2 for f4 and f6.");
      }
    }
  } else if (is_block_16) {
    bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                    ((firstScaleLane == 16) && (firstScaleByte == 2));
    if (!is_valid) {
      return emitOpError("blockSize of 16 can only have (firstScaleLane, "
                         "firstScaleByte) be (0, 0) or (16, 2) for f8.");
    }
  }

  return success();
}
ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
                                                IntegerAttr &m, IntegerAttr &n,
                                                IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";

  Builder &builder = parser.getBuilder();
  m = builder.getI32IntegerAttr(dimensions[0]);
  n = builder.getI32IntegerAttr(dimensions[1]);
  k = builder.getI32IntegerAttr(dimensions[2]);
  return success();
}
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}
LogicalResult ScaledWMMAOp::verify() {
  auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
  auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
  auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
  auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
  auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
  auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;

  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type aElemType = sourceAType.getElementType();
  Type bElemType = sourceBType.getElementType();

  int64_t m = getM();
  int64_t aLen = sourceAType.getNumElements();
  int64_t bLen = sourceBType.getNumElements();
  int64_t expectedOutLen = (m == 16) ? 8 : 16;

  if (destType.getNumElements() != expectedOutLen)
    return emitOpError("expected output vector of length ")
           << expectedOutLen << " but got " << destType.getNumElements();

  if (m == 16) {
    if (aLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceA must have 64 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 16x16x128, sourceB must have 64 elements but got ")
             << bLen;
  } else {
    if (!isF4(aElemType) && !isF4(bElemType))
      return emitOpError("32x16x128 only supports fp4 element types");
    if (aLen != 128)
      return emitOpError(
                 "for 32x16x128, sourceA must have 128 elements but got ")
             << aLen;
    if (bLen != 64)
      return emitOpError(
                 "for 32x16x128, sourceB must have 64 elements but got ")
             << bLen;
    if (getAFirstScaleLane() != 0)
      return emitOpError("for 32x16x128, a_first_scale_lane must be 0");
  }

  auto scaleAType = cast<VectorType>(getScaleA().getType());
  auto scaleBType = cast<VectorType>(getScaleB().getType());
  Type scaleAElemType = scaleAType.getElementType();
  Type scaleBElemType = scaleBType.getElementType();

  if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
    return emitOpError(
        "scale operands must have f8 element types (E8M0FNU or E4M3FN)");

  // Accept only the supported matrix/scale element-type combinations.
  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
    return success();
  if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
      isF4(bElemType) && isE4M3(scaleBElemType))
    return success();
  if (isF4(aElemType) && isE4M3(scaleAElemType) &&
      (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
    return success();
  if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) &&
      isE4M3(scaleBElemType))
    return success();

  return emitOpError("invalid combination of matrix and scale types: ")
         << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType
         << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType;
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    uint32_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }

  // Normalize the wider integer source types to i8 so the element-count
  // arithmetic below is uniform.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
LogicalResult SparseMFMAOp::verify() {
  constexpr uint32_t waveSize = 64;

  auto sparseType = cast<VectorType>(getSourceA().getType());
  auto denseType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sparseElem = sparseType.getElementType();
  Type denseElem = denseType.getElementType();
  int64_t sparseLen = sparseType.getNumElements();
  int64_t denseLen = denseType.getNumElements();
  int64_t destLen = destType.getNumElements();

  if (denseLen != 2 * sparseLen)
    return emitOpError("expected dense source operand to have exactly double "
                       "the number of elements of the sparse source operand");

  bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8);
  if (!bothFloat8 && sparseElem != denseElem)
    return emitOpError(
        "expected source operands to have the same element type");

  bool is8BitSource = sparseElem.getIntOrFloatBitWidth() == 8;
  if (getCbsz() == 0 && is8BitSource && getAbid() > 1)
    return emitOpError("ABID must be 0 or 1 for 8-bit source data");
  if (getCbsz() == 0 && !is8BitSource && getAbid() > 3)
    return emitOpError("ABID must be between 0 and 3 for 16-bit source data");

  auto sparseIdxType = cast<VectorType>(getSparseIdx().getType());
  if (is8BitSource) {
    if (sparseIdxType.getNumElements() != 2 ||
        !sparseIdxType.getElementType().isInteger(16))
      return emitOpError("expected vector<2xi16> sparse indices for 8-bit "
                         "source data, but got ")
             << getSparseIdx().getType();
  } else {
    if (sparseIdxType.getNumElements() != 4 ||
        !sparseIdxType.getElementType().isInteger(8))
      return emitOpError("expected vector<4xi8> sparse indices for 16-bit "
                         "source data, but got ")
             << getSparseIdx().getType();
  }

  int64_t expectedSourceElems = (getM() * getK()) / waveSize;
  if (denseLen != expectedSourceElems)
    return emitOpError("expected " + Twine(expectedSourceElems) +
                       " source values for this operation but got " +
                       Twine(denseLen));

  int64_t expectedDestElems = (getM() * getN()) / waveSize;
  if (destLen != expectedDestElems)
    return emitOpError("expected " + Twine(expectedDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64)
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }
  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }
  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();
  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");
  return success();
}
/// Remove an amdgpu.lds_barrier that is immediately followed by another
/// amdgpu.lds_barrier, since the second barrier already provides the fence.
static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op,
                                                 PatternRewriter &rewriter) {
  if (isa_and_nonnull<LDSBarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
                                PatternRewriter &rewriter) const override {
    auto next = dyn_cast_if_present<MemoryCounterWaitOp>(op->getNextNode());
    if (!next)
      return failure();

    auto setters = {&MemoryCounterWaitOp::setLoad,
                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
                    &MemoryCounterWaitOp::setExp,
                    &MemoryCounterWaitOp::setTensor};
    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
                    op.getTensor()};
    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
                    next.getExp(), next.getTensor()};
    rewriter.modifyOpInPlace(op, [&] {
      for (auto [setter, lhs, rhs] :
           llvm::zip_equal(setters, lhsVals, rhsVals)) {
        // When both waits bound a counter, keep the tighter (smaller) bound;
        // when only the second does, carry its bound over.
        if (lhs && rhs)
          (op.*setter)(std::min(*lhs, *rhs));
        else if (rhs)
          (op.*setter)(*rhs);
      }
    });
    rewriter.eraseOp(next);
    return success();
  }
};
void MemoryCounterWaitOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FuseMemoryCounterWaitOp>(context);
}
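// Fusion sketch (assembly spelled from the accessor names above; the exact
// printed form comes from the op's ODS definition): two adjacent waits like
//   amdgpu.memory_counter_wait load(2)
//   amdgpu.memory_counter_wait load(1) ds(0)
// collapse into `amdgpu.memory_counter_wait load(1) ds(0)`, since waiting
// until at most one load is outstanding subsumes waiting for at most two.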
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // The transferred chunk must be 8, 16, 32, 96, or 128 bits wide.
  auto transferType = getTransferType();
  int64_t transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transfering type size must be 8, 16, 32, 96 or 128 bits");

  Attribute srcMemSpace = srcType.getMemorySpace();
  if (!hasGlobalMemorySpace(srcMemSpace) &&
      !hasFatRawBufferMemorySpace(srcMemSpace))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");
  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    // Fold away a memref.cast feeding either operand when the cast only
    // changes static type information.
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Element type size (bits) -> required number of elements.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end())
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  if (numElements != validNumElems->second)
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;

  return success();
}
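// Note the pattern in the table: every legal combination transfers 64 bits
// per lane (16 x 4, 8 x 8, 4 x 16 bits) except the 6-bit case, which
// transfers 16 x 6 = 96 bits.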
template <typename BaseOp>
static LogicalResult verifyBase(BaseOp op) {
  auto ldsType = cast<MemRefType>(op.getLds().getType());
  auto globalType = cast<MemRefType>(op.getGlobal().getType());
  if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
    return op.emitOpError(
        "lds memref must have workgroup address space attribute.");
  if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
    return op.emitOpError(
        "global memref must have global address space attribute.");

  Type elementType = ldsType.getElementType();
  unsigned width = elementType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return op.emitOpError(
               "element type must be 1, 2, 4, or 8 bytes long but type was ")
           << width << " bits long.";
  return success();
}

LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
// Additional checks for the gather variant's element and index types. (The
// enclosing verifier's exact name and signature are elided in this listing;
// a helper taking the element and index types is assumed here.)
static LogicalResult
verifyGatherBaseTypes(function_ref<InFlightDiagnostic()> emitError,
                      Type elementType, Type indexType) {
  unsigned width = elementType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return emitError()
           << "element type must be 1, 2, 4, or 8 bytes wide but type "
           << elementType << " is " << width / 8 << " bytes wide.";

  MLIRContext *ctx = elementType.getContext();
  Type i16 = IntegerType::get(ctx, 16);
  Type i32 = IntegerType::get(ctx, 32);
  if (!llvm::is_contained({i16, i32}, indexType))
    return emitError() << "index type must be i16 or i32 but index type is "
                       << indexType;
  return success();
}

LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
template <typename DescriptorOp>
static LogicalResult verifyDescriptorOp(DescriptorOp op) {
  ArrayRef<int64_t> globalStaticStrides = op.getGlobalStaticStrides();
  if (globalStaticStrides.empty())
    return op.emitOpError("strides must not be empty.");
  if (globalStaticStrides.back() != 1)
    return op.emitOpError("strides for the innermost dimension must be 1.");

  ArrayRef<int64_t> globalStaticSizes = op.getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 5)
    return op.emitOpError("tensor and tile must be at most of rank 5.");
  if (rank != globalStaticStrides.size())
    return op.emitOpError("strides and sizes must have same rank.");

  ArrayRef<int64_t> sharedStaticSizes = op.getSharedStaticSizes();
  if (rank != sharedStaticSizes.size())
    return op.emitOpError("tensor must have same rank as tile.");

  unsigned elementTypeWidth = op.getElementTypeWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
    return op.emitOpError(
               "element type width must be 1, 2, 4 or 8 bytes, but was ")
           << elementTypeWidth << " bits long";

  if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) {
    auto atomicBarrierAddressType =
        cast<MemRefType>(atomicBarrierAddress.getType());
    if (!hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()))
      return op.emitOpError("atomic barrier address must be in LDS.");
  }

  if (op.getEarlyTimeout() && !op.getWorkgroupMask())
    return op.emitOpError(
        "early timeout does not apply when workgroup_mask is not set.");
  return success();
}
template <typename DescriptorOp, typename FoldAdaptor>
static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) {
  // Fold constant values feeding the dynamic size/stride operands into the
  // static attribute arrays. (The getMixed* accessors are assumed here to
  // follow the usual MLIR convention for mixed static/dynamic lists.)
  SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
  SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
  SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes();
  bool folded = succeeded(foldDynamicIndexList(mixedGlobalSizes));
  folded |= succeeded(foldDynamicIndexList(mixedGlobalStrides));
  folded |= succeeded(foldDynamicIndexList(mixedSharedSizes));
  if (!folded)
    return nullptr;

  SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides,
      dynamicSharedSizes;
  SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides,
      staticSharedSizes;
  dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
                             staticGlobalSizes);
  op.setGlobalStaticSizes(staticGlobalSizes);
  op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);

  dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
                             staticGlobalStrides);
  op.setGlobalStaticStrides(staticGlobalStrides);
  op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);

  dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
                             staticSharedSizes);
  op.setSharedStaticSizes(staticSharedSizes);
  op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
  return op.getResult();
}
LogicalResult MakeDmaDescriptorOp::verify() {
  return verifyDescriptorOp(*this);
}

OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}
LogicalResult MakeGatherDmaDescriptorOp::verify() {
  ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 2)
    return emitOpError(
        "tensor and tile must be at most of rank two in gather mode.");

  Value indices = getIndices();
  Type elementType = cast<VectorType>(indices.getType()).getElementType();
  // (Element-type accessor on the base descriptor's type assumed here.)
  if (elementType != getBase().getType().getElementType())
    return emitOpError("indices' element type must match base's element type.");

  return verifyDescriptorOp(*this);
}

OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  return foldDescriptorOp(*this, adaptor);
}
/// Pack scale values for amdgpu.scaled_mfma: when a scalar scale is extracted
/// from a larger vector and re-inserted into the 4-byte scale operand, use a
/// 4-element slice of the source instead and select the element through the
/// scalesIdx (opsel) attribute.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Operands 3 and 4 are the A and B scales.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      if (isa<VectorType>(insertOp.getValueToStore().getType()))
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");

      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType)
        return rewriter.notifyMatchFailure(op, "scale source is not a vector");
      if (!scaleSrcType.hasStaticShape())
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4)
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");

      // Linearize the extract position to a flat index into the source.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // Choose a 4-element window containing idx; opsel selects the element
      // within that window, clamping the window to the end of the source.
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      constexpr int64_t size = 4;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
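// Worked example: if the A-scale is built by extracting element 6 of a
// vector<8xf8E8M0FNU> and inserting it into the 4-byte scale operand, the
// pattern rewrites operand 3 to an extract_strided_slice of elements [4, 8)
// and sets scalesIdxA (opsel) to 6 - 4 = 2, avoiding the scalar round trip.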
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"