22 #include "../LLVMCommon/MemRefDescriptor.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
29 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
30 #include "mlir/Conversion/Passes.h.inc"
47 auto valTy = cast<IntegerType>(val.
getType());
50 return valTy.getWidth() > 32
51 ?
Value(rewriter.
create<LLVM::TruncOp>(loc, i32, val))
52 :
Value(rewriter.
create<LLVM::ZExtOp>(loc, i32, val));
58 return rewriter.
create<LLVM::ConstantOp>(loc, i32, value);
64 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI1, value);
76 ShapedType::isDynamic(stride)
78 memRefDescriptor.
stride(rewriter, loc, i))
79 : rewriter.
create<LLVM::ConstantOp>(loc, i32, stride);
80 increment = rewriter.
create<LLVM::MulOp>(loc, increment, strideValue);
83 index ? rewriter.
create<LLVM::AddOp>(loc, index, increment) : increment;
92 MemRefType memrefType,
95 uint32_t elementByteWidth) {
96 if (memrefType.hasStaticShape() &&
97 !llvm::any_of(strides, ShapedType::isDynamic)) {
98 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
100 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
101 size =
std::max(shape[i] * strides[i], size);
102 size = size * elementByteWidth;
104 "the memref buffer is too large");
108 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
109 Value size = memrefDescriptor.
size(rewriter, loc, i);
110 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
111 Value maxThisDim = rewriter.
create<LLVM::MulOp>(loc, size, stride);
113 ? rewriter.
create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
118 return rewriter.
create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
123 bool boundsCheck, amdgpu::Chipset chipset,
124 Value cacheSwizzleStride =
nullptr,
125 unsigned addressSpace = 8) {
131 if (chipset.majorVersion == 9 && chipset >=
kGfx942 && cacheSwizzleStride) {
132 Value cacheStrideZext =
133 rewriter.
create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
134 Value swizzleBit = rewriter.
create<LLVM::ConstantOp>(
136 stride = rewriter.
create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
139 stride = rewriter.
create<LLVM::ConstantOp>(loc, i16,
157 uint32_t flags = (7 << 12) | (4 << 15);
158 if (chipset.majorVersion >= 10) {
160 uint32_t oob = boundsCheck ? 3 : 2;
161 flags |= (oob << 28);
167 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
172 struct FatRawBufferCastLowering
181 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
184 Value memRef = adaptor.getSource();
185 Value unconvertedMemref = op.getSource();
186 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
190 int64_t elementByteWidth =
193 int64_t unusedOffset = 0;
195 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
196 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
198 Value numRecords = adaptor.getValidBytes();
200 numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor,
201 strideVals, elementByteWidth);
204 adaptor.getResetOffset()
205 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
207 : descriptor.alignedPtr(rewriter, loc);
209 Value offset = adaptor.getResetOffset()
210 ? rewriter.
create<LLVM::ConstantOp>(
212 : descriptor.offset(rewriter, loc);
214 bool hasSizes = memrefType.getRank() > 0;
217 Value sizes = hasSizes ? rewriter.
create<LLVM::ExtractValueOp>(
220 Value strides = hasSizes
221 ? rewriter.
create<LLVM::ExtractValueOp>(
226 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
227 chipset, adaptor.getCacheSwizzleStride(), 7);
231 getTypeConverter()->convertType(op.getResult().getType()));
232 result = rewriter.
create<LLVM::InsertValueOp>(
234 result = rewriter.
create<LLVM::InsertValueOp>(
236 result = rewriter.
create<LLVM::InsertValueOp>(loc, result, offset,
239 result = rewriter.
create<LLVM::InsertValueOp>(loc, result, sizes,
241 result = rewriter.
create<LLVM::InsertValueOp>(
250 template <
typename GpuOp,
typename Intrinsic>
256 static constexpr uint32_t maxVectorOpWidth = 128;
259 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
262 Value memref = adaptor.getMemref();
263 Value unconvertedMemref = gpuOp.getMemref();
264 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
267 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
269 Value storeData = adaptor.getODSOperands(0)[0];
270 if (storeData == memref)
274 wantedDataType = storeData.
getType();
276 wantedDataType = gpuOp.getODSResults(0)[0].getType();
281 Value maybeCmpData = adaptor.getODSOperands(1)[0];
282 if (maybeCmpData != memref)
283 atomicCmpData = maybeCmpData;
286 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
292 int64_t elementByteWidth =
301 Type llvmBufferValType = llvmWantedDataType;
303 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
304 llvmBufferValType = this->getTypeConverter()->convertType(
307 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
308 uint32_t vecLen = dataVector.getNumElements();
311 uint32_t totalBits = elemBits * vecLen;
313 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
314 if (totalBits > maxVectorOpWidth)
315 return gpuOp.emitOpError(
316 "Total width of loads or stores must be no more than " +
317 Twine(maxVectorOpWidth) +
" bits, but we call for " +
319 " bits. This should've been caught in validation");
320 if (!usePackedFp16 && elemBits < 32) {
321 if (totalBits > 32) {
322 if (totalBits % 32 != 0)
323 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
324 "doesn't fit into words. Can't happen\n");
325 llvmBufferValType = this->typeConverter->convertType(
328 llvmBufferValType = this->typeConverter->convertType(
333 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
336 if (vecType.getNumElements() == 1)
337 llvmBufferValType = vecType.getElementType();
342 if (llvmBufferValType != llvmWantedDataType) {
344 rewriter.
create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
345 args.push_back(castForStore);
347 args.push_back(storeData);
352 if (llvmBufferValType != llvmWantedDataType) {
353 Value castForCmp = rewriter.
create<LLVM::BitcastOp>(
354 loc, llvmBufferValType, atomicCmpData);
355 args.push_back(castForCmp);
357 args.push_back(atomicCmpData);
364 if (failed(memrefType.getStridesAndOffset(strides, offset)))
365 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
369 Value ptr = memrefDescriptor.bufferPtr(
370 rewriter, loc, *this->getTypeConverter(), memrefType);
372 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
374 adaptor.getBoundsCheck(), chipset);
375 args.push_back(resource);
379 adaptor.getIndices(), strides);
380 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
381 indexOffset && *indexOffset > 0) {
384 voffset ? rewriter.
create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
387 voffset = rewriter.
create<LLVM::MulOp>(loc, voffset, byteWidthConst);
388 args.push_back(voffset);
391 Value sgprOffset = adaptor.getSgprOffset();
394 sgprOffset = rewriter.
create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
395 args.push_back(sgprOffset);
404 Operation *lowered = rewriter.
create<Intrinsic>(loc, resultTypes, args,
408 if (llvmBufferValType != llvmWantedDataType) {
409 replacement = rewriter.
create<LLVM::BitcastOp>(loc, llvmWantedDataType,
427 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
431 if (requiresInlineAsm) {
433 LLVM::AsmDialect::AD_ATT);
435 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
436 const char *constraints =
"";
440 asmStr, constraints,
true,
441 false, asmDialectAttr,
446 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
447 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
450 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
454 ldsOnlyBits = ldsOnlyBitsGfx11;
456 ldsOnlyBits = ldsOnlyBitsGfx10;
458 ldsOnlyBits = ldsOnlyBitsGfx6789;
460 return op.emitOpError(
461 "don't know how to lower this for chipset major version")
465 rewriter.
create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
469 rewriter.
create<ROCDL::WaitDscntOp>(loc, 0);
470 rewriter.
create<ROCDL::BarrierSignalOp>(loc, -1);
485 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
488 (uint32_t)op.getOpts());
511 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
512 if (vectorType.getElementType().isBF16())
513 return rewriter.
create<LLVM::BitcastOp>(
514 loc, vectorType.clone(rewriter.
getI16Type()), input);
515 if (vectorType.getElementType().isInteger(8) &&
516 vectorType.getNumElements() <= 8)
517 return rewriter.
create<LLVM::BitcastOp>(
518 loc, rewriter.
getIntegerType(vectorType.getNumElements() * 8), input);
519 if (isa<IntegerType>(vectorType.getElementType()) &&
520 vectorType.getElementTypeBitWidth() <= 8) {
522 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
524 return rewriter.
create<LLVM::BitcastOp>(
542 bool isUnsigned,
Value llvmInput,
546 auto vectorType = dyn_cast<VectorType>(inputType);
548 operands.push_back(llvmInput);
551 Type elemType = vectorType.getElementType();
554 llvmInput = rewriter.
create<LLVM::BitcastOp>(
555 loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
557 operands.push_back(llvmInput);
564 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
565 bool isInputInteger = mlirInputType.getElementType().isInteger();
566 if (isInputInteger) {
568 bool localIsUnsigned = isUnsigned;
570 localIsUnsigned =
true;
572 localIsUnsigned =
false;
575 operands.push_back(sign);
581 Type intrinsicInType = numBits <= 32
584 auto llvmIntrinsicInType = typeConverter->
convertType(intrinsicInType);
586 loc, llvmIntrinsicInType, llvmInput);
591 castInput = rewriter.
create<LLVM::ZExtOp>(loc, i32, castInput);
592 operands.push_back(castInput);
605 Value output, int32_t subwordOffset,
608 auto vectorType = dyn_cast<VectorType>(inputType);
609 Type elemType = vectorType.getElementType();
611 output = rewriter.
create<LLVM::BitcastOp>(
612 loc, vectorType.clone(rewriter.
getI16Type()), output);
613 operands.push_back(output);
624 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
625 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
631 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
632 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
640 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
641 b = mfma.getBlocks();
646 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
647 if (m == 32 && n == 32 && k == 4 && b == 1)
648 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
649 if (m == 16 && n == 16 && k == 8 && b == 1)
650 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
652 if (m == 32 && n == 32 && k == 1 && b == 2)
653 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
654 if (m == 16 && n == 16 && k == 1 && b == 4)
655 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
656 if (m == 4 && n == 4 && k == 1 && b == 16)
657 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
658 if (m == 32 && n == 32 && k == 2 && b == 1)
659 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
660 if (m == 16 && n == 16 && k == 4 && b == 1)
661 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
666 if (m == 32 && n == 32 && k == 16 && b == 1)
667 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
668 if (m == 16 && n == 16 && k == 32 && b == 1)
669 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
671 if (m == 32 && n == 32 && k == 4 && b == 2)
672 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
673 if (m == 16 && n == 16 && k == 4 && b == 4)
674 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
675 if (m == 4 && n == 4 && k == 4 && b == 16)
676 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
677 if (m == 32 && n == 32 && k == 8 && b == 1)
678 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
679 if (m == 16 && n == 16 && k == 16 && b == 1)
680 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
685 if (m == 32 && n == 32 && k == 16 && b == 1)
686 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
687 if (m == 16 && n == 16 && k == 32 && b == 1)
688 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
691 if (m == 32 && n == 32 && k == 4 && b == 2)
692 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
693 if (m == 16 && n == 16 && k == 4 && b == 4)
694 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
695 if (m == 4 && n == 4 && k == 4 && b == 16)
696 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
697 if (m == 32 && n == 32 && k == 8 && b == 1)
698 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
699 if (m == 16 && n == 16 && k == 16 && b == 1)
700 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
702 if (m == 32 && n == 32 && k == 2 && b == 2)
703 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
704 if (m == 16 && n == 16 && k == 2 && b == 4)
705 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
706 if (m == 4 && n == 4 && k == 2 && b == 16)
707 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
708 if (m == 32 && n == 32 && k == 4 && b == 1)
709 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
710 if (m == 16 && n == 16 && k == 8 && b == 1)
711 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
716 if (m == 32 && n == 32 && k == 32 && b == 1)
717 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
718 if (m == 16 && n == 16 && k == 64 && b == 1)
719 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
721 if (m == 32 && n == 32 && k == 4 && b == 2)
722 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
723 if (m == 16 && n == 16 && k == 4 && b == 4)
724 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
725 if (m == 4 && n == 4 && k == 4 && b == 16)
726 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
727 if (m == 32 && n == 32 && k == 8 && b == 1)
728 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
729 if (m == 16 && n == 16 && k == 16 && b == 1)
730 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
731 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >=
kGfx942)
732 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
733 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >=
kGfx942)
734 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
738 if (m == 16 && n == 16 && k == 4 && b == 1)
739 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
740 if (m == 4 && n == 4 && k == 4 && b == 4)
741 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
748 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
749 if (m == 16 && n == 16 && k == 32 && b == 1) {
751 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
753 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
755 if (m == 32 && n == 32 && k == 16 && b == 1) {
757 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
759 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
765 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
766 if (m == 16 && n == 16 && k == 32 && b == 1) {
768 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
770 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
772 if (m == 32 && n == 32 && k == 16 && b == 1) {
774 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
776 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
785 .Case([](Float8E4M3FNType) {
return 0u; })
786 .Case([](Float8E5M2Type) {
return 1u; })
787 .Case([](Float6E2M3FNType) {
return 2u; })
788 .Case([](Float6E3M2FNType) {
return 3u; })
789 .Case([](Float4E2M1FNType) {
return 4u; })
790 .Default([](
Type) {
return std::nullopt; });
800 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
802 uint32_t n, uint32_t k, uint32_t b,
Chipset chipset) {
809 if (!isa<Float32Type>(destType))
814 if (!aTypeCode || !bTypeCode)
817 if (m == 32 && n == 32 && k == 64 && b == 1)
818 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
819 *aTypeCode, *bTypeCode};
820 if (m == 16 && n == 16 && k == 128 && b == 1)
822 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
828 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
831 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
832 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
833 mfma.getBlocks(), chipset);
841 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
842 auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
843 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
844 auto elemSourceType = sourceVectorType.getElementType();
845 auto elemBSourceType = sourceBVectorType.getElementType();
846 auto elemDestType = destVectorType.getElementType();
848 if (elemSourceType.isF16() && elemDestType.isF32())
849 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
850 if (elemSourceType.isBF16() && elemDestType.isF32())
851 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
852 if (elemSourceType.isF16() && elemDestType.isF16())
853 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
854 if (elemSourceType.isBF16() && elemDestType.isBF16())
855 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
856 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
857 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
859 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
860 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
863 if (isa<Float8E4M3FNType>(elemSourceType) &&
864 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
865 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
866 if (isa<Float8E4M3FNType>(elemSourceType) &&
867 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
868 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
869 if (isa<Float8E5M2Type>(elemSourceType) &&
870 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
871 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
872 if (isa<Float8E5M2Type>(elemSourceType) &&
873 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
874 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
875 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
876 bool isWave64 = destVectorType.getNumElements() == 4;
879 bool has8Inputs = sourceVectorType.getNumElements() == 8;
880 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
881 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
882 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
896 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
899 Type outType = typeConverter->convertType(op.getDestD().getType());
900 Type intrinsicOutType = outType;
901 if (
auto outVecType = dyn_cast<VectorType>(outType))
902 if (outVecType.getElementType().isBF16())
903 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
906 return op->emitOpError(
"MFMA only supported on gfx908+");
907 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
908 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
910 return op.emitOpError(
"negation unsupported on older than gfx942");
912 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
915 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
917 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
918 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
921 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
923 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
924 return op.emitOpError(
925 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
926 "be scaled as those fields are used for type information");
929 StringRef intrinsicName =
930 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
932 loweredOp.addTypes(intrinsicOutType);
933 loweredOp.addOperands(
936 adaptor.getDestC()});
939 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
950 if (outType != intrinsicOutType)
951 lowered = rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered);
964 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
968 typeConverter->convertType<VectorType>(op.getDestD().getType());
973 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
977 VectorType rawOutType = outType;
978 if (outType.getElementType().
isBF16())
979 rawOutType = outType.clone(rewriter.
getI16Type());
983 if (!maybeIntrinsic.has_value())
984 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
986 if (chipset.
majorVersion >= 12 && op.getSubwordOffset() != 0)
987 return op.emitOpError(
"subwordOffset not supported on gfx12+");
990 loweredOp.addTypes(rawOutType);
994 adaptor.getSourceA(), op.getSourceA(), operands);
996 adaptor.getSourceB(), op.getSourceB(), operands);
998 op.getSubwordOffset(), op.getClamp(), operands);
1000 loweredOp.addOperands(operands);
1004 if (rawOutType != outType)
1006 rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered->
getResult(0));
1020 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1023 return op.emitOpError(
"chipset not supported");
1027 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1028 auto dstMemRefType = cast<MemRefType>(op.getSrc().getType());
1033 Type transferType = op.getTransferType();
1034 size_t loadWidth = [&]() ->
size_t {
1035 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1036 return transferVectorType.getNumElements() *
1037 (transferVectorType.getElementTypeBitWidth() / 8);
1044 if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
1045 return op.emitOpError(
"chipset unsupported element size");
1047 Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
1048 (adaptor.getSrcIndices()), rewriter);
1049 Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
1050 (adaptor.getDstIndices()), rewriter);
1063 struct ExtPackedFp8OpLowering final
1071 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1075 struct PackedTrunc2xFp8OpLowering final
1084 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1088 struct PackedStochRoundFp8OpLowering final
1097 matchAndRewrite(PackedStochRoundFp8Op op,
1098 PackedStochRoundFp8OpAdaptor adaptor,
1103 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1104 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1109 loc,
"Fp8 conversion instructions are not available on target "
1110 "architecture and their emulation is not implemented");
1113 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1114 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1116 Value source = adaptor.getSource();
1117 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1118 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1121 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1122 Value longVec = rewriter.
create<LLVM::UndefOp>(loc, v4i8);
1123 if (!sourceVecType) {
1124 longVec = rewriter.
create<LLVM::InsertElementOp>(
1127 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1129 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
1131 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1136 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
1137 if (resultVecType) {
1159 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1160 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1165 loc,
"Fp8 conversion instructions are not available on target "
1166 "architecture and their emulation is not implemented");
1167 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1169 Type resultType = op.getResult().getType();
1172 Value sourceA = adaptor.getSourceA();
1173 Value sourceB = adaptor.getSourceB();
1175 sourceB = rewriter.
create<LLVM::UndefOp>(loc, sourceA.
getType());
1176 Value existing = adaptor.getExisting();
1178 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
1180 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
1185 result = rewriter.
create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
1188 result = rewriter.
create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
1192 op, getTypeConverter()->convertType(resultType), result);
1196 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1197 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1202 loc,
"Fp8 conversion instructions are not available on target "
1203 "architecture and their emulation is not implemented");
1204 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1206 Type resultType = op.getResult().getType();
1209 Value source = adaptor.getSource();
1210 Value stoch = adaptor.getStochiasticParam();
1211 Value existing = adaptor.getExisting();
1213 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
1215 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
1220 result = rewriter.
create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
1223 result = rewriter.
create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
1227 op, getTypeConverter()->convertType(resultType), result);
1239 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1244 Value src = adaptor.getSrc();
1245 Value old = adaptor.getOld();
1248 Type llvmType =
nullptr;
1251 }
else if (isa<FloatType>(srcType)) {
1255 }
else if (isa<IntegerType>(srcType)) {
1260 auto llvmSrcIntType = typeConverter->convertType(
1264 auto convertOperand = [&](
Value operand,
Type operandType) {
1265 if (operandType.getIntOrFloatBitWidth() <= 16) {
1266 if (llvm::isa<FloatType>(operandType)) {
1268 rewriter.
create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
1271 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1272 Value undefVec = rewriter.
create<LLVM::UndefOp>(loc, llvmVecType);
1273 operand = rewriter.
create<LLVM::InsertElementOp>(
1275 operand = rewriter.
create<LLVM::BitcastOp>(loc, llvmType, operand);
1280 src = convertOperand(src, srcType);
1281 old = convertOperand(old, oldType);
1284 enum DppCtrl :
unsigned {
1293 ROW_HALF_MIRROR = 0x141,
1298 auto kind = DppOp.getKind();
1299 auto permArgument = DppOp.getPermArgument();
1300 uint32_t DppCtrl = 0;
1304 case DPPPerm::quad_perm:
1305 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1307 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1308 uint32_t num = elem.getInt();
1309 DppCtrl |= num << (i * 2);
1314 case DPPPerm::row_shl:
1315 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1316 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1319 case DPPPerm::row_shr:
1320 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1321 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1324 case DPPPerm::row_ror:
1325 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1326 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1329 case DPPPerm::wave_shl:
1330 DppCtrl = DppCtrl::WAVE_SHL1;
1332 case DPPPerm::wave_shr:
1333 DppCtrl = DppCtrl::WAVE_SHR1;
1335 case DPPPerm::wave_rol:
1336 DppCtrl = DppCtrl::WAVE_ROL1;
1338 case DPPPerm::wave_ror:
1339 DppCtrl = DppCtrl::WAVE_ROR1;
1341 case DPPPerm::row_mirror:
1342 DppCtrl = DppCtrl::ROW_MIRROR;
1344 case DPPPerm::row_half_mirror:
1345 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1347 case DPPPerm::row_bcast_15:
1348 DppCtrl = DppCtrl::BCAST15;
1350 case DPPPerm::row_bcast_31:
1351 DppCtrl = DppCtrl::BCAST31;
1357 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
1358 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
1359 bool boundCtrl = DppOp->getAttrOfType<
BoolAttr>(
"bound_ctrl").getValue();
1362 auto dppMovOp = rewriter.
create<ROCDL::DPPUpdateOp>(
1363 loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1365 Value result = dppMovOp.getRes();
1367 result = rewriter.
create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
1368 if (!llvm::isa<IntegerType>(srcType)) {
1369 result = rewriter.
create<LLVM::BitcastOp>(loc, srcType, result);
1380 struct ConvertAMDGPUToROCDLPass
1381 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1384 void runOnOperation()
override {
1387 if (failed(maybeChipset)) {
1389 return signalPassFailure();
1396 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1397 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1398 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1401 signalPassFailure();
1413 switch (as.getValue()) {
1414 case amdgpu::AddressSpace::FatRawBuffer:
1416 case amdgpu::AddressSpace::BufferRsrc:
1418 case amdgpu::AddressSpace::FatStructuredBuffer:
1430 .add<FatRawBufferCastLowering,
1431 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1432 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1433 RawBufferOpLowering<RawBufferAtomicFaddOp,
1434 ROCDL::RawPtrBufferAtomicFaddOp>,
1435 RawBufferOpLowering<RawBufferAtomicFmaxOp,
1436 ROCDL::RawPtrBufferAtomicFmaxOp>,
1437 RawBufferOpLowering<RawBufferAtomicSmaxOp,
1438 ROCDL::RawPtrBufferAtomicSmaxOp>,
1439 RawBufferOpLowering<RawBufferAtomicUminOp,
1440 ROCDL::RawPtrBufferAtomicUminOp>,
1441 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1442 ROCDL::RawPtrBufferAtomicCmpSwap>,
1443 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1444 MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1445 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
1446 GatherToLDSOpLowering>(converter, chipset);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
constexpr Chipset kGfx908
constexpr Chipset kGfx90a
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, uint32_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
constexpr Chipset kGfx950
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1195::ArityGroupAndKind::Kind kind
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides a shared interface for ranked and unranked memref types.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult abort()
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOcpFp8(const Chipset &chipset)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.