23 #include "../LLVMCommon/MemRefDescriptor.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/ErrorHandling.h"
32 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
33 #include "mlir/Conversion/Passes.h.inc"
// NOTE(review): fragmentary capture — embedded leading numbers are the original
// file's line numbers and interior lines are elided. Do not edit without the
// full source. The statements below belong to several small lowering helpers.
//
// Fragment 1: coerce an integer Value to i32 — truncate when wider than 32
// bits, zero-extend otherwise.
50 auto valTy = cast<IntegerType>(val.
getType());
53 return valTy.getWidth() > 32
54 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
55 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
// Fragment 2: constant-materialization helpers (an i32 constant and an i1
// constant, respectively).
61 return LLVM::ConstantOp::create(rewriter, loc, i32, value);
67 return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
// Fragment 3: linearized-index accumulation. For each dimension, the stride is
// read from the memref descriptor when dynamic, or materialized as an i32
// constant when static; the running index accumulates index += i_d * stride_d.
79 ShapedType::isDynamic(stride)
81 memRefDescriptor.
stride(rewriter, loc, i))
82 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
83 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
// `index` starts null; first contribution becomes the index, later ones are added.
85 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
// NOTE(review): fragmentary capture of getNumRecords (interior lines elided).
// Computes the extent, in bytes, of the addressable region of a memref for use
// as a buffer-resource `numRecords` field.
95 MemRefType memrefType,
98 uint32_t elementByteWidth) {
// Static fast path: when shape and strides are fully static, the byte size is
// max over dims of (shape[i] * strides[i]) * elementByteWidth, computed at
// compile time. Rank-0 memrefs hold one element, hence the `? 1 : 0` seed.
99 if (memrefType.hasStaticShape() &&
100 !llvm::any_of(strides, ShapedType::isDynamic)) {
101 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
103 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
104 size =
std::max(shape[i] * strides[i], size);
105 size = size * elementByteWidth;
// (elided guard) — error message suggests an overflow/size-limit check here.
107 "the memref buffer is too large");
// Dynamic path: build the same max(size*stride) reduction at runtime as a
// chain of mul + unsigned-max ops reading from the memref descriptor.
111 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
112 Value size = memrefDescriptor.
size(rewriter, loc, i);
113 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
114 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
116 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
// Scale the element-count extent by the element byte width to get bytes.
121 return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
// NOTE(review): fragmentary capture (interior lines elided) of the helper that
// builds an AMDGPU buffer-resource descriptor (default address space 8 —
// presumably the fat raw-buffer address space; confirm against full file).
126 bool boundsCheck, amdgpu::Chipset chipset,
127 Value cacheSwizzleStride =
nullptr,
128 unsigned addressSpace = 8) {
// gfx9 at or above gfx942: a cache-swizzle stride, when provided, is
// zero-extended to i16 and OR'd with a swizzle-enable bit to form the
// resource's stride field.
134 if (chipset.majorVersion == 9 && chipset >=
kGfx942 && cacheSwizzleStride) {
135 Value cacheStrideZext =
136 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
137 Value swizzleBit = LLVM::ConstantOp::create(
139 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
// Otherwise the stride field is a plain i16 constant (value elided here).
142 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
// Flags word for the resource descriptor; bit meanings are hardware-defined —
// TODO(review): confirm field semantics against the ISA docs for the target.
160 uint32_t flags = (7 << 12) | (4 << 15);
161 if (chipset.majorVersion >= 10) {
// gfx10+: out-of-bounds behavior selector depends on whether bounds checking
// was requested (3 = checked, 2 = unchecked per this encoding).
163 uint32_t oob = boundsCheck ? 3 : 2;
164 flags |= (oob << 28);
// Final resource construction from base pointer, stride, extent, and flags.
170 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
// NOTE(review): fragmentary capture of FatRawBufferCastLowering (interior
// lines elided). Lowers amdgpu.fat_raw_buffer_cast: wraps a memref in a buffer
// fat pointer and rebuilds the result memref descriptor around it.
175 struct FatRawBufferCastLowering
184 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
187 Value memRef = adaptor.getSource();
188 Value unconvertedMemref = op.getSource();
189 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
193 int64_t elementByteWidth =
// Only strided (stride/offset-form) memrefs can be lowered.
196 int64_t unusedOffset = 0;
198 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
199 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
// validBytes operand, when present, overrides the computed extent; otherwise
// derive it from the memref shape/strides.
201 Value numRecords = adaptor.getValidBytes();
203 numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor,
204 strideVals, elementByteWidth);
// resetOffset: base the resource at the true buffer start and zero the
// descriptor offset; otherwise keep the aligned pointer and existing offset.
207 adaptor.getResetOffset()
208 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
210 : descriptor.alignedPtr(rewriter, loc);
212 Value offset = adaptor.getResetOffset()
213 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
215 : descriptor.offset(rewriter, loc);
// Rank-0 memrefs carry no size/stride arrays in the descriptor.
217 bool hasSizes = memrefType.getRank() > 0;
220 Value sizes = hasSizes
221 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
225 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
// Build the fat pointer (address space 7 here) from the buffer resource.
230 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
231 chipset, adaptor.getCacheSwizzleStride(), 7);
// Reassemble the result descriptor: fat pointer as both allocated and aligned
// pointer, then offset, sizes, and strides.
235 getTypeConverter()->convertType(op.getResult().getType()));
237 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
238 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
240 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
243 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
245 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
// NOTE(review): fragmentary capture (interior lines elided) of the shared
// pattern lowering raw-buffer loads/stores/atomics (GpuOp) to a ROCDL
// buffer Intrinsic. Do not edit without the full source.
254 template <
typename GpuOp,
typename Intrinsic>
// Hardware limit on a single buffer access width, in bits.
260 static constexpr uint32_t maxVectorOpWidth = 128;
263 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
266 Value memref = adaptor.getMemref();
267 Value unconvertedMemref = gpuOp.getMemref();
268 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
271 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
// ODS operand 0 doubles as store data; for loads it aliases the memref
// operand, which is how "no store data" is detected below.
273 Value storeData = adaptor.getODSOperands(0)[0];
274 if (storeData == memref)
278 wantedDataType = storeData.
getType();
280 wantedDataType = gpuOp.getODSResults(0)[0].getType();
// Same aliasing trick for the optional compare operand of cmpswap-style ops.
285 Value maybeCmpData = adaptor.getODSOperands(1)[0];
286 if (maybeCmpData != memref)
287 atomicCmpData = maybeCmpData;
290 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
296 int64_t elementByteWidth =
// The buffer intrinsics operate on a restricted set of types; compute the
// type actually used on the wire (llvmBufferValType) and bitcast around it.
305 Type llvmBufferValType = llvmWantedDataType;
307 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
308 llvmBufferValType = this->getTypeConverter()->convertType(
311 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
312 uint32_t vecLen = dataVector.getNumElements();
315 uint32_t totalBits = elemBits * vecLen;
// Packed-fp16 fadd is a special case that keeps 2xf16 as-is.
317 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
318 if (totalBits > maxVectorOpWidth)
319 return gpuOp.emitOpError(
320 "Total width of loads or stores must be no more than " +
321 Twine(maxVectorOpWidth) +
" bits, but we call for " +
323 " bits. This should've been caught in validation");
// Sub-32-bit elements are repacked into whole i32 words (or one short word).
324 if (!usePackedFp16 && elemBits < 32) {
325 if (totalBits > 32) {
326 if (totalBits % 32 != 0)
327 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
328 "doesn't fit into words. Can't happen\n");
329 llvmBufferValType = this->typeConverter->convertType(
332 llvmBufferValType = this->typeConverter->convertType(
// Degenerate 1-element vectors become scalars for the intrinsic.
337 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
340 if (vecType.getNumElements() == 1)
341 llvmBufferValType = vecType.getElementType();
// Store data / cmp data are bitcast to the wire type when it differs.
346 if (llvmBufferValType != llvmWantedDataType) {
347 Value castForStore = LLVM::BitcastOp::create(
348 rewriter, loc, llvmBufferValType, storeData);
349 args.push_back(castForStore);
351 args.push_back(storeData);
356 if (llvmBufferValType != llvmWantedDataType) {
357 Value castForCmp = LLVM::BitcastOp::create(
358 rewriter, loc, llvmBufferValType, atomicCmpData);
359 args.push_back(castForCmp);
361 args.push_back(atomicCmpData);
368 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
369 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
// Build the buffer resource argument from the descriptor and extent.
373 Value ptr = memrefDescriptor.bufferPtr(
374 rewriter, loc, *this->getTypeConverter(), memrefType);
376 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
378 adaptor.getBoundsCheck(), chipset);
379 args.push_back(resource);
// voffset: linearized element index (indices x strides), plus the optional
// static indexOffset, then scaled to bytes.
383 adaptor.getIndices(), strides);
384 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
385 indexOffset && *indexOffset > 0) {
387 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
391 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
392 args.push_back(voffset);
// sgprOffset is likewise byte-scaled before being appended.
395 Value sgprOffset = adaptor.getSgprOffset();
398 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
399 args.push_back(sgprOffset);
408 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
// Cast the intrinsic result back to the type the original op promised.
412 if (llvmBufferValType != llvmWantedDataType) {
413 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
// NOTE(review): fragmentary capture of encodeWaitcnt (the per-chipset guard
// lines between the returns are elided). Packs vmcnt/expcnt/lgkmcnt into the
// s_waitcnt immediate; the bit layout varies by GPU generation, which is why
// there are several return shapes below.
435 static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
436 unsigned expcnt,
unsigned lgkmcnt) {
// Oldest layout: vmcnt[3:0] | expcnt<<4 | lgkmcnt<<8.
441 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
// Wider-vmcnt layout: low 4 bits stay in place, bits above 4 move to bit 14.
447 unsigned lowBits = vmcnt & 0xF;
448 unsigned highBits = (vmcnt >> 4) << 14;
449 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
450 return lowBits | highBits | otherCnts;
// Same split-vmcnt layout for another generation (guards elided — presumably
// distinct lgkmcnt widths in the full source; confirm before editing).
456 unsigned lowBits = vmcnt & 0xF;
457 unsigned highBits = (vmcnt >> 4) << 14;
458 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
459 return lowBits | highBits | otherCnts;
// Newest layout seen here: vmcnt<<10 | expcnt | lgkmcnt<<4.
465 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
// NOTE(review): fragmentary capture of MemoryCounterWaitOpLowering (interior
// lines elided). Two lowering strategies are visible: separate typed wait ops
// (newer chipsets), or one combined s_waitcnt immediate via encodeWaitcnt.
470 struct MemoryCounterWaitOpLowering
480 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
// Path 1: emit one dedicated ROCDL wait op per requested counter.
484 if (std::optional<int> ds = adaptor.getDs())
485 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
487 if (std::optional<int> load = adaptor.getLoad())
488 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
490 if (std::optional<int> store = adaptor.getStore())
491 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
493 if (std::optional<int> exp = adaptor.getExp())
494 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
// Path 2: fold attributes into a single waitcnt word. getVal presumably maps
// an absent attribute to a "don't wait" default on an elided line — confirm.
500 auto getVal = [](
Attribute attr) ->
unsigned {
502 return cast<IntegerAttr>(attr).getInt();
507 unsigned ds = getVal(adaptor.getDsAttr());
508 unsigned exp = getVal(adaptor.getExpAttr());
// vmcnt covers both loads and stores on chipsets with a unified counter; 1024
// appears to be the "no constraint" sentinel here.
510 unsigned vmcnt = 1024;
512 Attribute store = adaptor.getStoreAttr();
514 vmcnt = getVal(load) + getVal(store);
516 vmcnt = getVal(load);
518 vmcnt = getVal(store);
521 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
523 return op.emitOpError(
"unsupported chipset");
// NOTE(review): fragmentary capture of the LDS-barrier lowering plus the start
// of the sched-barrier lowering (interior lines elided).
537 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
// Fallback path: inline asm "s_waitcnt lgkmcnt(0); s_barrier" — explicitly
// flagged as breaking debug watchpoints.
541 if (requiresInlineAsm) {
543 LLVM::AsmDialect::AD_ATT);
545 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
546 const char *constraints =
"";
550 asmStr, constraints,
true,
// Masks that keep only the LDS (lgkmcnt) field of a waitcnt immediate; field
// position/width differs per generation.
557 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
558 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
561 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
565 ldsOnlyBits = ldsOnlyBitsGfx11;
567 ldsOnlyBits = ldsOnlyBitsGfx10;
569 ldsOnlyBits = ldsOnlyBitsGfx6789;
571 return op.emitOpError(
572 "don't know how to lower this for chipset major version")
576 ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
// Newest path: dedicated dscnt wait plus split barrier signal (-1 mask).
580 ROCDL::WaitDscntOp::create(rewriter, loc, 0);
581 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
// Sched-barrier lowering: forwards the op's option mask verbatim.
596 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
599 (uint32_t)op.getOpts());
// NOTE(review): fragmentary capture (interior lines elided) of the helper that
// adapts an MFMA source operand to the type its intrinsic expects.
623 bool allowBf16 =
true) {
625 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
// bf16 vectors are reinterpreted as i16 vectors when the target intrinsic
// does not take bf16 directly.
626 if (vectorType.getElementType().isBF16() && !allowBf16)
627 return LLVM::BitcastOp::create(
628 rewriter, loc, vectorType.clone(rewriter.
getI16Type()), input);
// Small i8 vectors (<= 8 lanes) are bitcast to a packed form (target type on
// an elided line).
629 if (vectorType.getElementType().isInteger(8) &&
630 vectorType.getNumElements() <= 8)
631 return LLVM::BitcastOp::create(
// Sub-byte integer vectors are packed into a single integer of the combined
// bit width before the cast.
634 if (isa<IntegerType>(vectorType.getElementType()) &&
635 vectorType.getElementTypeBitWidth() <= 8) {
637 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
639 return LLVM::BitcastOp::create(
// Scalar tail (elided context suggests scale-operand handling): integers are
// zero-extended, everything else bitcast, to the intrinsic's operand type.
661 if (
auto intType = dyn_cast<IntegerType>(inputType))
662 return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
663 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
// NOTE(review): fragmentary capture (interior lines elided) of the helpers
// that marshal WMMA input and output operands into the intrinsic operand list.
677 bool isUnsigned,
Value llvmInput,
// Non-vector inputs pass through unchanged.
681 auto vectorType = dyn_cast<VectorType>(inputType);
683 operands.push_back(llvmInput);
686 Type elemType = vectorType.getElementType();
// bf16 vectors are reinterpreted as i16 vectors for the intrinsic.
689 llvmInput = LLVM::BitcastOp::create(
690 rewriter, loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
692 operands.push_back(llvmInput);
// Integer inputs additionally carry a signedness flag operand, derived from
// the unconverted MLIR type (conditions on the elided lines).
699 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
700 bool isInputInteger = mlirInputType.getElementType().isInteger();
701 if (isInputInteger) {
703 bool localIsUnsigned = isUnsigned;
705 localIsUnsigned =
true;
707 localIsUnsigned =
false;
710 operands.push_back(sign);
// Small-int payloads are repacked into a <=32-bit integer and zext'd to i32
// when narrower.
716 Type intrinsicInType = numBits <= 32
719 auto llvmIntrinsicInType = typeConverter->
convertType(intrinsicInType);
721 loc, llvmIntrinsicInType, llvmInput);
726 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
727 operands.push_back(castInput);
// Output-side helper: bf16 accumulators are likewise passed as i16 vectors;
// subwordOffset/clamp handling is on elided lines.
740 Value output, int32_t subwordOffset,
743 auto vectorType = dyn_cast<VectorType>(inputType);
744 Type elemType = vectorType.getElementType();
746 output = LLVM::BitcastOp::create(
747 rewriter, loc, vectorType.clone(rewriter.
getI16Type()), output);
748 operands.push_back(output);
// NOTE(review): fragments of two predicates mapping "bf8"/"fp8" to the
// concrete 8-bit float type each chipset uses: gfx942 uses the FNUZ variants,
// OCP-fp8-capable chipsets use the standard OCP types.
759 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
760 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
766 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
767 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
// NOTE(review): fragmentary capture of the MFMA-shape -> ROCDL intrinsic name
// dispatch (the per-element-type guard lines between case groups are elided).
// Each group matches on (M, N, K, blocks); do not edit without the full file.
775 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
776 b = mfma.getBlocks();
// xf32 (reduced-precision f32) cases, gfx942+ only.
781 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
782 if (m == 32 && n == 32 && k == 4 && b == 1)
783 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
784 if (m == 16 && n == 16 && k == 8 && b == 1)
785 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
// f32-source cases.
787 if (m == 32 && n == 32 && k == 1 && b == 2)
788 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
789 if (m == 16 && n == 16 && k == 1 && b == 4)
790 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
791 if (m == 4 && n == 4 && k == 1 && b == 16)
792 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
793 if (m == 32 && n == 32 && k == 2 && b == 1)
794 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
795 if (m == 16 && n == 16 && k == 4 && b == 1)
796 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
// f16-source cases (the first two shapes are the large-K variants).
801 if (m == 32 && n == 32 && k == 16 && b == 1)
802 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
803 if (m == 16 && n == 16 && k == 32 && b == 1)
804 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
806 if (m == 32 && n == 32 && k == 4 && b == 2)
807 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
808 if (m == 16 && n == 16 && k == 4 && b == 4)
809 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
810 if (m == 4 && n == 4 && k == 4 && b == 16)
811 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
812 if (m == 32 && n == 32 && k == 8 && b == 1)
813 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
814 if (m == 16 && n == 16 && k == 16 && b == 1)
815 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
// bf16-source cases (large-K, then "_1k" variants, then the original forms).
820 if (m == 32 && n == 32 && k == 16 && b == 1)
821 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
822 if (m == 16 && n == 16 && k == 32 && b == 1)
823 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
826 if (m == 32 && n == 32 && k == 4 && b == 2)
827 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
828 if (m == 16 && n == 16 && k == 4 && b == 4)
829 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
830 if (m == 4 && n == 4 && k == 4 && b == 16)
831 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
832 if (m == 32 && n == 32 && k == 8 && b == 1)
833 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
834 if (m == 16 && n == 16 && k == 16 && b == 1)
835 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
837 if (m == 32 && n == 32 && k == 2 && b == 2)
838 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
839 if (m == 16 && n == 16 && k == 2 && b == 4)
840 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
841 if (m == 4 && n == 4 && k == 2 && b == 16)
842 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
843 if (m == 32 && n == 32 && k == 4 && b == 1)
844 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
845 if (m == 16 && n == 16 && k == 8 && b == 1)
846 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
// i8-source cases; the k=16/k=32 single-block forms are gated on gfx942+.
851 if (m == 32 && n == 32 && k == 32 && b == 1)
852 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
853 if (m == 16 && n == 16 && k == 64 && b == 1)
854 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
856 if (m == 32 && n == 32 && k == 4 && b == 2)
857 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
858 if (m == 16 && n == 16 && k == 4 && b == 4)
859 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
860 if (m == 4 && n == 4 && k == 4 && b == 16)
861 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
862 if (m == 32 && n == 32 && k == 8 && b == 1)
863 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
864 if (m == 16 && n == 16 && k == 16 && b == 1)
865 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
866 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >=
kGfx942)
867 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
868 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >=
kGfx942)
869 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
// f64 cases.
873 if (m == 16 && n == 16 && k == 4 && b == 1)
874 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
875 if (m == 4 && n == 4 && k == 4 && b == 4)
876 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
// Mixed fp8/bf8 cases: sourceA type selected the group (guard elided); the
// sourceB element type picks the _bf8/_fp8 suffix.
883 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
884 if (m == 16 && n == 16 && k == 32 && b == 1) {
886 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
888 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
890 if (m == 32 && n == 32 && k == 16 && b == 1) {
892 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
894 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
900 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
901 if (m == 16 && n == 16 && k == 32 && b == 1) {
903 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
905 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
907 if (m == 32 && n == 32 && k == 16 && b == 1) {
909 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
911 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
// NOTE(review): fragments of the scaled-MFMA support code. First: a TypeSwitch
// mapping element types to the hardware's f8f6f4 format codes (fp8=0, bf8=1,
// fp6=2, bf6=3, fp4=4; unknown -> nullopt).
920 .Case([](Float8E4M3FNType) {
return 0u; })
921 .Case([](Float8E5M2Type) {
return 1u; })
922 .Case([](Float6E2M3FNType) {
return 2u; })
923 .Case([](Float6E3M2FNType) {
return 3u; })
924 .Case([](Float4E2M1FNType) {
return 4u; })
925 .Default([](
Type) {
return std::nullopt; });
// Selector returning (intrinsic name, aTypeCode, bTypeCode) for the two
// supported scaled-MFMA shapes; f32 destination required.
935 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
937 uint32_t n, uint32_t k, uint32_t b,
Chipset chipset) {
944 if (!isa<Float32Type>(destType))
949 if (!aTypeCode || !bTypeCode)
952 if (m == 32 && n == 32 && k == 64 && b == 1)
953 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
954 *aTypeCode, *bTypeCode};
955 if (m == 16 && n == 16 && k == 128 && b == 1)
957 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
// Convenience overload for a plain MFMAOp (blocks taken from the op).
963 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
966 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
967 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
968 mfma.getBlocks(), chipset);
// Convenience overload for ScaledMFMAOp (blocks fixed to 1).
971 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
974 smfma.getSourceB().getType(),
975 smfma.getDestC().getType(), smfma.getM(),
976 smfma.getN(), smfma.getK(), 1u, chipset);
// NOTE(review): fragmentary capture of the WMMA intrinsic-name selector
// (chipset-generation guards between case groups are elided).
984 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
985 auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
986 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
987 auto elemSourceType = sourceVectorType.getElementType();
988 auto elemBSourceType = sourceBVectorType.getElementType();
989 auto elemDestType = destVectorType.getElementType();
// 16x16x16 float/integer cases.
991 if (elemSourceType.isF16() && elemDestType.isF32())
992 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
993 if (elemSourceType.isBF16() && elemDestType.isF32())
994 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
995 if (elemSourceType.isF16() && elemDestType.isF16())
996 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
997 if (elemSourceType.isBF16() && elemDestType.isBF16())
998 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
999 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1000 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1002 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1003 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
// fp8/bf8 mixed-precision cases: A-type x B-type pick the suffix pair.
1006 if (isa<Float8E4M3FNType>(elemSourceType) &&
1007 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1008 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1009 if (isa<Float8E4M3FNType>(elemSourceType) &&
1010 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1011 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1012 if (isa<Float8E5M2Type>(elemSourceType) &&
1013 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1014 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1015 if (isa<Float8E5M2Type>(elemSourceType) &&
1016 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1017 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
// iu4 disambiguation: wave size (4-element dest => wave64) and input length
// together select the 16x16x32 vs 16x16x16 variant.
1018 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
1019 bool isWave64 = destVectorType.getNumElements() == 4;
1022 bool has8Inputs = sourceVectorType.getNumElements() == 8;
1023 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
1024 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1025 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1028 return std::nullopt;
// NOTE(review): fragmentary capture of MFMAOpLowering::matchAndRewrite
// (interior lines elided). Selects either a plain or a scaled MFMA intrinsic
// and marshals operands; bf16 results travel as i16 vectors.
1039 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1042 Type outType = typeConverter->convertType(op.getDestD().getType());
1043 Type intrinsicOutType = outType;
1044 if (
auto outVecType = dyn_cast<VectorType>(outType))
1045 if (outVecType.getElementType().isBF16())
1046 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
1049 return op->emitOpError(
"MFMA only supported on gfx908+");
// blgp doubles as the negate-flags field on gfx942+; negation is rejected on
// older chipsets.
1050 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
1051 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1053 return op.emitOpError(
"negation unsupported on older than gfx942");
1055 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
// Try the plain intrinsic first, then the scaled (f8f6f4) fallback.
1058 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1060 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1061 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
1064 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
// Scaled intrinsics reuse abid/blgp/cbsz for type codes, so user-set values
// are rejected there.
1066 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1067 return op.emitOpError(
1068 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1069 "be scaled as those fields are used for type information");
1072 StringRef intrinsicName =
1073 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
// Only specific large-K bf16 intrinsics accept bf16 operands natively; other
// cases bitcast bf16 inputs to i16 (via the allowBf16 flag).
1076 bool allowBf16 = [&]() {
1081 return intrinsicName.contains(
"16x16x32.bf16") ||
1082 intrinsicName.contains(
"32x32x16.bf16");
1085 loweredOp.addTypes(intrinsicOutType);
1087 rewriter, loc, adaptor.getSourceA(), allowBf16),
1089 rewriter, loc, adaptor.getSourceB(), allowBf16),
1090 adaptor.getDestC()});
1093 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
// Cast i16-vector results back to the bf16 type callers expect.
1104 if (outType != intrinsicOutType)
1105 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
// NOTE(review): fragmentary capture of ScaledMFMAOpLowering::matchAndRewrite
// (interior lines elided). Lowers amdgpu.scaled_mfma to the f8f6f4 scaled
// ROCDL intrinsic selected by mfmaOpToScaledIntrinsic.
1118 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1121 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1124 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1125 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1127 if (!maybeScaledIntrinsic.has_value())
1128 return op.emitOpError(
1129 "no intrinsic matching scaled MFMA size on given chipset");
1131 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1133 loweredOp.addTypes(intrinsicOutType);
// Operand order: A, B, C first, then (on elided lines) the type codes and
// scale operands.
1134 loweredOp.addOperands(
1137 adaptor.getDestC()});
1142 loweredOp.addOperands(
// NOTE(review): fragmentary capture of WMMAOpLowering::matchAndRewrite
// (interior lines elided). bf16 accumulators are carried as i16 vectors and
// cast back after the intrinsic.
1164 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1168 typeConverter->convertType<VectorType>(op.getDestD().getType());
1173 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1177 VectorType rawOutType = outType;
1178 if (outType.getElementType().
isBF16())
1179 rawOutType = outType.clone(rewriter.
getI16Type());
1183 if (!maybeIntrinsic.has_value())
1184 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
// gfx12 dropped subword-offset support.
1186 if (chipset.
majorVersion >= 12 && op.getSubwordOffset() != 0)
1187 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1190 loweredOp.addTypes(rawOutType);
// Marshal A, B, then C (+ subwordOffset/clamp) via the wmma operand helpers.
1194 adaptor.getSourceA(), op.getSourceA(), operands);
1196 adaptor.getSourceB(), op.getSourceB(), operands);
1198 op.getSubwordOffset(), op.getClamp(), operands);
1200 loweredOp.addOperands(operands);
1204 if (rawOutType != outType)
1205 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
// NOTE(review): fragmentary capture of TransposeLoadOpLowering (interior
// lines elided). gfx950-only: lowers amdgpu.transpose_load to the ds_read_tr*
// LDS transpose-read intrinsics, dispatched on result element bit width.
1213 struct TransposeLoadOpLowering
1221 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1224 return op.emitOpError(
"Non-gfx950 chipset not supported");
1227 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
// Sub-byte source elements are rejected (addressing would not be byte-exact).
1231 size_t srcElementSize =
1232 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1233 if (srcElementSize < 8)
1234 return op.emitOpError(
"Expect source memref to have at least 8 bits "
1235 "element size, got ")
1238 auto resultType = cast<VectorType>(op.getResult().getType());
1241 (adaptor.getSrcIndices()));
1243 size_t numElements = resultType.getNumElements();
1244 size_t elementTypeSize =
1245 resultType.getElementType().getIntOrFloatBitWidth();
1251 Type llvmResultType = typeConverter->convertType(resultType);
// Element size selects the intrinsic; the asserts record each intrinsic's
// fixed lane count (tr4_b64: 16, tr6_b96: 16, tr8_b64: 8, last case: 4).
1253 switch (elementTypeSize) {
1255 assert(numElements == 16);
1256 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1257 rocdlResultType, srcPtr);
1262 assert(numElements == 16);
1263 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1264 rocdlResultType, srcPtr);
1269 assert(numElements == 8);
1270 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1271 rocdlResultType, srcPtr);
1276 assert(numElements == 4);
1282 return op.emitOpError(
"Unsupported element size for transpose load");
// NOTE(review): fragmentary capture of the GatherToLDSOp lowering (interior
// lines elided). Global->LDS gather; supported on gfx9 only per the visible
// guard message, with width restrictions below.
1295 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1298 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
1302 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1303 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
// Transfer width in bytes, derived from the transfer type (vector: lanes *
// element bits / 8; scalar path is on elided lines).
1308 Type transferType = op.getTransferType();
1309 int loadWidth = [&]() ->
int {
1310 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1311 return (transferVectorType.getNumElements() *
1312 transferVectorType.getElementTypeBitWidth()) /
// Hardware supports 1/2/4-byte transfers everywhere; 12/16-byte only gfx950.
1319 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1320 return op.emitOpError(
"chipset unsupported element size");
1322 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1323 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
1324 "16-byte load widths are only supported on gfx950");
// Source and destination addresses resolved from descriptors + indices.
1328 (adaptor.getSrcIndices()));
1331 (adaptor.getDstIndices()));
// NOTE(review): declarations (bodies defined out-of-line below) of the
// conversion patterns for packed 8-bit / scaled float conversions.
1344 struct ExtPackedFp8OpLowering final
1352 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1356 struct PackedTrunc2xFp8OpLowering final
1365 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1369 struct PackedStochRoundFp8OpLowering final
1378 matchAndRewrite(PackedStochRoundFp8Op op,
1379 PackedStochRoundFp8OpAdaptor adaptor,
1383 struct ScaledExtPackedOpLowering final
1391 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1395 struct PackedScaledTruncOpLowering final
1404 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
// NOTE(review): fragmentary capture (interior lines elided). Lowers
// amdgpu.ext_packed_fp8: the fp8 source is widened into a 4-byte vector,
// bitcast to i32, and fed to the cvt intrinsic (selection lines elided).
1410 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1411 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
// Chipsets without native fp8 conversion are rejected (no emulation).
1416 loc,
"Fp8 conversion instructions are not available on target "
1417 "architecture and their emulation is not implemented");
1420 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1421 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1423 Value source = adaptor.getSource();
1424 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1425 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
// Pad scalar or short-vector sources out to a full 4xi8 word so the bitcast
// to i32 below is well-formed.
1428 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1429 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1430 if (!sourceVecType) {
1431 longVec = LLVM::InsertElementOp::create(
1434 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1436 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1438 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1443 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1444 if (resultVecType) {
// NOTE(review): fragmentary capture (interior lines elided). Lowers
// amdgpu.scaled_ext_packed: packs the fp8/fp4 source into an i32 and selects
// one of nine CvtScale intrinsics by (source type x dest type).
1464 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1465 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1470 loc,
"Scaled fp conversion instructions are not available on target "
1471 "architecture and their emulation is not implemented");
1472 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1474 Value source = adaptor.getSource();
1475 Value scale = adaptor.getScale();
1477 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1478 Type sourceElemType = sourceVecType.getElementType();
1479 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1480 Type destElemType = destVecType.getElementType();
// The packed wire format is 4xi8 for fp8 sources, 8xi4 for fp4 sources.
1482 VectorType packedVecType;
1483 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1485 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1486 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
1488 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1490 llvm_unreachable(
"invalid element type for scaled ext");
// Short sources are zero-padded into the packed vector before the bitcast.
// NOTE(review): `!sourceVecType` can never be true after the cast<> above —
// this branch looks like dead code inherited from the unscaled version;
// confirm against the full file before changing.
1494 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1495 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
1496 if (!sourceVecType) {
1497 longVec = LLVM::InsertElementOp::create(
1500 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1502 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1504 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1509 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
// Intrinsic dispatch: (bf8|fp8|fp4 source) x (f32|f16|bf16 dest).
1511 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
1513 op, destVecType, i32Source, scale, op.getIndex());
1514 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
1516 op, destVecType, i32Source, scale, op.getIndex());
1517 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
1519 op, destVecType, i32Source, scale, op.getIndex());
1520 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
1522 op, destVecType, i32Source, scale, op.getIndex());
1523 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
1525 op, destVecType, i32Source, scale, op.getIndex());
1526 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
1528 op, destVecType, i32Source, scale, op.getIndex());
1529 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
1531 op, destVecType, i32Source, scale, op.getIndex());
1532 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
1534 op, destVecType, i32Source, scale, op.getIndex());
1535 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
1537 op, destVecType, i32Source, scale, op.getIndex());
1544 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1545 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1550 loc,
"Scaled fp conversion instructions are not available on target "
1551 "architecture and their emulation is not implemented");
1552 Type v2i16 = getTypeConverter()->convertType(
1554 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1556 Type resultType = op.getResult().getType();
1558 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1559 Type sourceElemType = sourceVecType.getElementType();
1561 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1563 Value source = adaptor.getSource();
1564 Value scale = adaptor.getScale();
1565 Value existing = adaptor.getExisting();
1567 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1569 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1571 if (sourceVecType.getNumElements() < 2) {
1573 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1575 source = LLVM::ZeroOp::create(rewriter, loc, v2);
1576 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1579 Value sourceA, sourceB;
1580 if (sourceElemType.
isF32()) {
1583 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1584 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1588 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
1589 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1590 existing, sourceA, sourceB,
1591 scale, op.getIndex());
1592 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
1593 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1594 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1595 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
1596 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1597 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1598 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
1599 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1600 existing, sourceA, sourceB,
1601 scale, op.getIndex());
1602 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
1603 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1604 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1605 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
1606 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1607 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1608 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
1609 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1610 existing, sourceA, sourceB,
1611 scale, op.getIndex());
1612 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
1613 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1614 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1615 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
1616 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1617 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1622 op, getTypeConverter()->convertType(resultType), result);
1626 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1627 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1632 loc,
"Fp8 conversion instructions are not available on target "
1633 "architecture and their emulation is not implemented");
1634 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1636 Type resultType = op.getResult().getType();
1639 Value sourceA = adaptor.getSourceA();
1640 Value sourceB = adaptor.getSourceB();
1642 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
1643 Value existing = adaptor.getExisting();
1645 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1647 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1651 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1652 existing, op.getWordIndex());
1654 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1655 existing, op.getWordIndex());
1658 op, getTypeConverter()->convertType(resultType), result);
1662 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1663 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1668 loc,
"Fp8 conversion instructions are not available on target "
1669 "architecture and their emulation is not implemented");
1670 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1672 Type resultType = op.getResult().getType();
1675 Value source = adaptor.getSource();
1676 Value stoch = adaptor.getStochiasticParam();
1677 Value existing = adaptor.getExisting();
1679 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1681 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1685 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
1686 existing, op.getStoreIndex());
1688 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
1689 existing, op.getStoreIndex());
1692 op, getTypeConverter()->convertType(resultType), result);
1704 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1709 Value src = adaptor.getSrc();
1710 Value old = adaptor.getOld();
1713 Type llvmType =
nullptr;
1716 }
else if (isa<FloatType>(srcType)) {
1720 }
else if (isa<IntegerType>(srcType)) {
1725 auto llvmSrcIntType = typeConverter->convertType(
1729 auto convertOperand = [&](
Value operand,
Type operandType) {
1730 if (operandType.getIntOrFloatBitWidth() <= 16) {
1731 if (llvm::isa<FloatType>(operandType)) {
1733 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
1736 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1737 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
1739 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
1741 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
1746 src = convertOperand(src, srcType);
1747 old = convertOperand(old, oldType);
1750 enum DppCtrl :
unsigned {
1759 ROW_HALF_MIRROR = 0x141,
1764 auto kind = DppOp.getKind();
1765 auto permArgument = DppOp.getPermArgument();
1766 uint32_t DppCtrl = 0;
1770 case DPPPerm::quad_perm:
1771 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1773 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1774 uint32_t num = elem.getInt();
1775 DppCtrl |= num << (i * 2);
1780 case DPPPerm::row_shl:
1781 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1782 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1785 case DPPPerm::row_shr:
1786 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1787 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1790 case DPPPerm::row_ror:
1791 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1792 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1795 case DPPPerm::wave_shl:
1796 DppCtrl = DppCtrl::WAVE_SHL1;
1798 case DPPPerm::wave_shr:
1799 DppCtrl = DppCtrl::WAVE_SHR1;
1801 case DPPPerm::wave_rol:
1802 DppCtrl = DppCtrl::WAVE_ROL1;
1804 case DPPPerm::wave_ror:
1805 DppCtrl = DppCtrl::WAVE_ROR1;
1807 case DPPPerm::row_mirror:
1808 DppCtrl = DppCtrl::ROW_MIRROR;
1810 case DPPPerm::row_half_mirror:
1811 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1813 case DPPPerm::row_bcast_15:
1814 DppCtrl = DppCtrl::BCAST15;
1816 case DPPPerm::row_bcast_31:
1817 DppCtrl = DppCtrl::BCAST31;
1823 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
1824 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
1825 bool boundCtrl = DppOp->getAttrOfType<
BoolAttr>(
"bound_ctrl").getValue();
1829 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
1830 rowMask, bankMask, boundCtrl);
1832 Value result = dppMovOp.getRes();
1834 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
1835 if (!llvm::isa<IntegerType>(srcType)) {
1836 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
1847 struct AMDGPUSwizzleBitModeLowering
1852 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1856 Value src = adaptor.getSrc();
1859 unsigned andMask = op.getAndMask();
1860 unsigned orMask = op.getOrMask();
1861 unsigned xorMask = op.getXorMask();
1865 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
1868 for (
Value v : decomposed) {
1870 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
1871 swizzled.emplace_back(res);
1888 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
1891 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
1895 Value src = adaptor.getSrc();
1896 unsigned rowLength = op.getRowLength();
1897 bool fi = op.getFetchInactive();
1898 bool boundctrl = op.getBoundCtrl();
1904 for (
Value v : decomposed) {
1906 Type i32pair = LLVM::LLVMStructType::getLiteral(
1907 rewriter.
getContext(), {v.getType(), v.getType()});
1909 if (rowLength == 16)
1910 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1912 else if (rowLength == 32)
1913 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1916 llvm_unreachable(
"unsupported row length");
1918 Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
1919 permuted.emplace_back(vdstNew);
1928 struct ConvertAMDGPUToROCDLPass
1929 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1932 void runOnOperation()
override {
1935 if (
failed(maybeChipset)) {
1937 return signalPassFailure();
1944 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1945 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1946 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1949 signalPassFailure();
1961 switch (as.getValue()) {
1962 case amdgpu::AddressSpace::FatRawBuffer:
1964 case amdgpu::AddressSpace::BufferRsrc:
1966 case amdgpu::AddressSpace::FatStructuredBuffer:
1978 .add<FatRawBufferCastLowering,
1979 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1980 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1981 RawBufferOpLowering<RawBufferAtomicFaddOp,
1982 ROCDL::RawPtrBufferAtomicFaddOp>,
1983 RawBufferOpLowering<RawBufferAtomicFmaxOp,
1984 ROCDL::RawPtrBufferAtomicFmaxOp>,
1985 RawBufferOpLowering<RawBufferAtomicSmaxOp,
1986 ROCDL::RawPtrBufferAtomicSmaxOp>,
1987 RawBufferOpLowering<RawBufferAtomicUminOp,
1988 ROCDL::RawPtrBufferAtomicUminOp>,
1989 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1990 ROCDL::RawPtrBufferAtomicCmpSwap>,
1991 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
1992 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
1993 WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1994 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1995 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1996 TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
1997 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
constexpr Chipset kGfx908
constexpr Chipset kGfx90a
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, uint32_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
constexpr Chipset kGfx950
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult abort()
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool hasOcpFp8(const Chipset &chipset)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.