#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/AMDGPUAddrSpace.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
  return chipset >= Chipset(9, 0, 6);

  if (chipset == Chipset(9, 5, 0))
  IntegerType i32 = rewriter.getI32Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));

  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
  IntegerType i64 = rewriter.getI64Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));

  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
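// Note: the helpers above normalize integer values to the fixed widths the
// buffer intrinsics expect: values wider than the target width are truncated,
// narrower ones are zero-extended, and plain constants are materialized
// directly as i32/i64 LLVM constants.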
  IntegerType i32 = rewriter.getI32Type();
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    Value strideValue =
        ShapedType::isDynamic(stride)
            ? convertUnsignedToI32(rewriter, loc,
                                   memRefDescriptor.stride(rewriter, loc, i))
            : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
    increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
                           MemRefType memrefType,

  if (chipset >= kGfx1250 && !boundsCheck) {
    constexpr int64_t first45bits = (1ll << 45) - 1;

  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;

  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
               ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)

  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  Type i16 = rewriter.getI16Type();
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));

  flags |= (7 << 12) | (4 << 15);

  uint32_t oob = boundsCheck ? 3 : 2;
  flags |= (oob << 28);

      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
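// Sketch of what the code above assembles: the 16-bit stride word holds the
// optional cache-swizzle stride with bit 14 set to enable swizzling (and is 0
// when unused); the flags word gets (7 << 12) | (4 << 15) plus an
// out-of-bounds mode in bits 28-29 (3 with bounds checking, 2 without); both
// feed a ROCDL MakeBufferRsrcOp along with the base pointer and numRecords.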
struct FatRawBufferCastLowering
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),

  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
        getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
                      elementByteWidth, chipset, adaptor.getBoundsCheck());

        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,

        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), 7);

    Value result = MemRefDescriptor::poison(
        getTypeConverter()->convertType(op.getResult().getType()));
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
    rewriter.replaceOp(op, result);
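    // The cast result reuses the standard memref descriptor layout: the fat
    // (buffer) pointer is stored into both pointer slots, followed by the
    // (possibly reset) offset and, for ranked memrefs, the original sizes and
    // strides carried over from the source descriptor.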
template <typename GpuOp, typename Intrinsic>
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  static constexpr uint32_t maxVectorOpWidth = 128;

  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref)

      wantedDataType = storeData.getType();

      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    Value maybeCmpData = adaptor.getODSOperands(1)[0];
    if (maybeCmpData != memref)
      atomicCmpData = maybeCmpData;

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
    Type llvmBufferValType = llvmWantedDataType;

    if (auto floatType = dyn_cast<FloatType>(wantedDataType))
      llvmBufferValType = this->getTypeConverter()->convertType(
          rewriter.getIntegerType(floatType.getWidth()));

    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t totalBits = elemBits * vecLen;
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));

    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    SmallVector<Value, 6> args;

      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
        args.push_back(storeData);

      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
        args.push_back(atomicCmpData);
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
        getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
                      elementByteWidth, chipset, adaptor.getBoundsCheck());
                                   adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

                          adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (llvmBufferValType != llvmWantedDataType) {
      replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
    rewriter.eraseOp(gpuOp);
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
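// Worked example for the oldest encoding above (vmcnt | expcnt << 4 |
// lgkmcnt << 8): vmcnt = 3, expcnt = 1, lgkmcnt = 2 packs to
// 3 | (1 << 4) | (2 << 8) == 0x213. The later branches widen vmcnt to 6 bits
// by splitting it into a low nibble plus high bits shifted up to bit 14.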
struct MemoryCounterWaitOpLowering
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),

  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      if (std::optional<int> tensor = adaptor.getTensor())
        ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);

      rewriter.eraseOp(op);

    if (adaptor.getTensor())
      return op.emitOpError("unsupported chipset");

    auto getVal = [](Attribute attr) -> unsigned {
        return cast<IntegerAttr>(attr).getInt();

    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
      vmcnt = getVal(load) + getVal(store);
      vmcnt = getVal(load);
      vmcnt = getVal(store);

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
      return op.emitOpError("unsupported chipset");
    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    bool requiresInlineAsm = chipset < kGfx90a;

        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");

    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          asmStr, constraints, true,
          false, LLVM::TailCallKind::None,
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
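    // The lowering above is: a workgroup-scoped release fence tagged with an
    // "amdgpu-synchronize-as"="local" MMRA, then the barrier itself (inline
    // s_barrier asm before gfx90a, ROCDL s_barrier up to gfx11, and a split
    // barrier.signal/barrier.wait pair from gfx12 on), then an acquire fence
    // that replaces the original op.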
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
                                      bool allowBf16 = true) {

  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),

                                 bool allowBf16 = true) {
  auto vectorType = cast<VectorType>(inputType);

  if (vectorType.getElementType().isBF16() && !allowBf16)
    return LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);

  if (isa<IntegerType>(vectorType.getElementType()) &&
      vectorType.getElementTypeBitWidth() <= 8) {
    int64_t numWords = llvm::divideCeil(
        vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
    Type castType = (numWords > 1)
                        ? Type{VectorType::get(numWords, rewriter.getI32Type())}
                        : rewriter.getI32Type();
    return LLVM::BitcastOp::create(rewriter, loc, castType, input);
      .Case([&](IntegerType) {
        return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
      .Case([&](VectorType vectorType) {
        int64_t numElements = vectorType.getNumElements();
        assert((numElements == 4 || numElements == 8) &&
               "scale operand must be a vector of length 4 or 8");
        IntegerType outputType =
            (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
        return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
      .DefaultUnreachable("unexpected input type for scale operand");

      .Case([](Float8E8M0FNUType) { return 0; })
      .Case([](Float8E4M3FNType) { return 2; })
      .Default(std::nullopt);
static std::optional<StringRef>
  if (m == 16 && n == 16 && k == 128)
               ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
               : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();

  if (m == 32 && n == 16 && k == 128)
    return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
                     : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
                                 ConversionPatternRewriter &rewriter,
                                 Location loc,
  auto vectorType = dyn_cast<VectorType>(inputType);
    operands.push_back(llvmInput);

  Type elemType = vectorType.getElementType();
    operands.push_back(llvmInput);

    auto mlirInputType = cast<VectorType>(mlirInput.getType());
    bool isInputInteger = mlirInputType.getElementType().isInteger();
    if (isInputInteger) {
      bool localIsUnsigned = isUnsigned;
        localIsUnsigned = true;
        localIsUnsigned = false;
          NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));

  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);

                                  Value output, int32_t subwordOffset,
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  operands.push_back(output);
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));

  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();

    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();

    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();

    if (m == 32 && n == 32 && k == 16 && b == 1)
      return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1)
      return ROCDL::mfma_f32_16x16x32_f16::getOperationName();

    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();

    if (m == 32 && n == 32 && k == 16 && b == 1)
      return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1)
      return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();

    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();

    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();

    if (m == 32 && n == 32 && k == 32 && b == 1)
      return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
    if (m == 16 && n == 16 && k == 64 && b == 1)
      return ROCDL::mfma_i32_16x16x64_i8::getOperationName();

    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();

    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();

        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1) {
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();

        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1) {
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();

  return std::nullopt;
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
                         uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,

  return std::nullopt;
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
                                smfma.getSourceB().getType(),
                                smfma.getDestC().getType(), smfma.getM(),
                                smfma.getN(), smfma.getK(), 1u, chipset);
static std::optional<StringRef>
                   Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    return std::nullopt;

  if (k == 32 && !isRDNA3) {
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
    return std::nullopt;
                      Type elemBSourceType,
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
    return std::nullopt;

    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    return std::nullopt;

  return std::nullopt;
  bool isGfx950 = chipset >= kGfx950;

  uint32_t m = op.getM(), n = op.getN(), k = op.getK();

  if (m == 16 && n == 16 && k == 32) {
      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();

  if (m == 16 && n == 16 && k == 64) {
      return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
      return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();

  if (m == 16 && n == 16 && k == 128 && isGfx950) {
      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();

  if (m == 32 && n == 32 && k == 16) {
      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();

  if (m == 32 && n == 32 && k == 32) {
      return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
      return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();

  if (m == 32 && n == 32 && k == 64 && isGfx950) {
      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();

  return std::nullopt;
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  const uint32_t k = wmma.getK();

  if (isRDNA3 || isRDNA4)

  return std::nullopt;
static std::optional<SparseWMMAOpInfo>
  uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();

  if ((m != 16) || (n != 16))
    return std::nullopt;

          ROCDL::swmmac_f32_16x16x32_f16::getOperationName(), false, false,
          ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(), false, false,
          ROCDL::swmmac_f16_16x16x32_f16::getOperationName(), false, false,
          ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(), false, false,
          ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(), true, false,
          ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(), true, false,
          ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(), false,
          ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(), false,
          ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(), false,
          ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(), false,
          ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(), true, false,

  const bool isGFX1250 = chipset == kGfx1250;
  const bool isWavesize64 = swmmac.getWave64();
  if (isGFX1250 && !isWavesize64) {
          ROCDL::swmmac_f32_16x16x64_f16::getOperationName(), true, true,
          ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(), true, true,
          ROCDL::swmmac_f16_16x16x64_f16::getOperationName(), true, true,
          ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(), true, true,
          ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(), false,
          ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(), false,
          ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(), false,
          ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(), false,
          ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(), false,
          ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(), false,
          ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(), false,
          ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
          ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(), false,
          ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(), true, true,

  return std::nullopt;
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
        return op.emitOpError("negation unsupported on older than gfx942");
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);

    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;

    bool allowBf16 = [&]() {
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");

    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
                              rewriter, loc, adaptor.getSourceA(), allowBf16),
                              rewriter, loc, adaptor.getSourceB(), allowBf16),
                          adaptor.getDestC()});
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({zero, zero});
      loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
                               {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
                               {"opselA", rewriter.getI32IntegerAttr(0)},
                               {"opselB", rewriter.getI32IntegerAttr(0)}});
      loweredOp.addAttributes(
          {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
           {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
           {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
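    // bf16-typed MFMA results are carried as i16 vectors at the intrinsic
    // boundary (intrinsicOutType above) and bitcast back before replacing the
    // op; intrinsics that can be scaled reuse this path with cbsz/blgp
    // repurposed as type codes and zero scale operands appended.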
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
         adaptor.getDestC()});
    loweredOp.addOperands(
    loweredOp.addAttributes(
        {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
         {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
         {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
         {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
  SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestC().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 9 || chipset < kGfx942)
      return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
    bool isGfx950 = chipset >= kGfx950;

    Value c = adaptor.getDestC();

    if (!maybeIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching sparse MFMA on the given chipset");

    Value sparseIdx = adaptor.getSparseIdx();
    Type i32Type = rewriter.getI32Type();
    if (sparseIdx.getType() != i32Type)
      sparseIdx = LLVM::BitcastOp::create(rewriter, loc, i32Type, sparseIdx);

    OperationState loweredOp(loc, maybeIntrinsic.value());
    loweredOp.addTypes(outType);
    loweredOp.addOperands({a, b, c, sparseIdx});
    loweredOp.addAttributes(
        {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
         {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestD().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    VectorType rawOutType = outType;
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
                         op.getSourceA(), operands, attrs, "signA");
                         op.getSourceB(), operands, attrs, "signB");
                          op.getSubwordOffset(), op.getClamp(), operands,

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
    rewriter.replaceOp(op, maybeCastBack->getResults());
enum class DotFamily {

static std::optional<std::pair<StringRef, DotFamily>>
dotOpToIntrinsic(DotOp op, Chipset chipset) {
  Type aElem = cast<VectorType>(op.getSourceA().getType()).getElementType();
  Type bElem = cast<VectorType>(op.getSourceB().getType()).getElementType();
  Type dest = op.getDestC().getType();
  bool uA = op.getUnsignedA();
  bool uB = op.getUnsignedB();

      return {{ROCDL::fdot2::getOperationName(), DotFamily::Clamp}};
      return {{ROCDL::fdot2_f16_f16::getOperationName(), DotFamily::NoClamp}};
    return std::nullopt;

      return {{ROCDL::fdot2_f32_bf16::getOperationName(), DotFamily::Clamp}};
      return {{ROCDL::fdot2_bf16_bf16::getOperationName(), DotFamily::NoClamp}};
    return std::nullopt;

  if (isa<IntegerType>(aElem) && isa<IntegerType>(bElem) &&
    bool mixedSign = (uA != uB);
        return std::nullopt;
      switch (elemWidth) {
        name = ROCDL::sudot4::getOperationName();
        name = ROCDL::sudot8::getOperationName();
        return std::nullopt;
      return {{name, DotFamily::Sudot}};

    bool supported = false;
    switch (elemWidth) {
      name = uA ? ROCDL::udot2::getOperationName()
                : ROCDL::sdot2::getOperationName();
      name = uA ? ROCDL::udot4::getOperationName()
                : ROCDL::sdot4::getOperationName();
      name = uA ? ROCDL::udot8::getOperationName()
                : ROCDL::sdot8::getOperationName();
      return std::nullopt;
      return std::nullopt;
    return {{name, DotFamily::Clamp}};

  bool aIsFp8 = isa<Float8E4M3FNType>(aElem);
  bool aIsBf8 = isa<Float8E5M2Type>(aElem);
  bool bIsFp8 = isa<Float8E4M3FNType>(bElem);
  bool bIsBf8 = isa<Float8E5M2Type>(bElem);
  if ((aIsFp8 || aIsBf8) && (bIsFp8 || bIsBf8) && dest.isF32()) {
      return std::nullopt;
    if (aIsFp8 && bIsFp8)
      name = ROCDL::dot4_f32_fp8_fp8::getOperationName();
    else if (aIsFp8 && bIsBf8)
      name = ROCDL::dot4_f32_fp8_bf8::getOperationName();
    else if (aIsBf8 && bIsFp8)
      name = ROCDL::dot4_f32_bf8_fp8::getOperationName();
      name = ROCDL::dot4_f32_bf8_bf8::getOperationName();
    return {{name, DotFamily::NoClamp}};

  return std::nullopt;
  DotOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DotOp>(converter), chipset(chipset) {}

  matchAndRewrite(DotOp op, DotOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    std::optional<std::pair<StringRef, DotFamily>> maybeIntrinsic =
        dotOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic)
      return op.emitOpError("no intrinsic matching dot on the given chipset: ")
             << op.getSourceA().getType() << " * " << op.getSourceB().getType()
             << " + " << op.getDestC().getType();

    auto [intrinsicName, family] = maybeIntrinsic.value();

    Value c = adaptor.getDestC();

    SmallVector<NamedAttribute, 3> attrs;
    if (family == DotFamily::Sudot) {
      attrs.push_back(rewriter.getNamedAttr(
          "signA", rewriter.getBoolAttr(!op.getUnsignedA())));
      attrs.push_back(rewriter.getNamedAttr(
          "signB", rewriter.getBoolAttr(!op.getUnsignedB())));

    if (family != DotFamily::NoClamp && op.getClamp())
          rewriter.getNamedAttr("clamp", rewriter.getBoolAttr(true)));

    Type resultType = typeConverter->convertType(op.getDestD().getType());

    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(resultType);
    loweredOp.addOperands({a, b, c});
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());
  SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestD().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    std::optional<SparseWMMAOpInfo> maybeIntrinsic =
    if (!maybeIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching Sparse WMMA on the given chipset");
    SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();

    SmallVector<NamedAttribute> attrs;

    if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.useSign)
      return op->emitOpError("intrinsic doesn't support unsign");

    if (auto attr = op.getUnsignedAAttr())
      attrs.push_back({"signA", attr});
    if (auto attr = op.getUnsignedBAttr())
      attrs.push_back({"signB", attr});

    if ((op.getReuseA() || op.getReuseB()) && !intrinsic.useReuse)
      return op->emitOpError("intrinsic doesn't support reuse");

    if (auto attr = op.getReuseAAttr())
      attrs.push_back({"reuseA", attr});
    if (auto attr = op.getReuseBAttr())
      attrs.push_back({"reuseB", attr});

    if (op.getClamp() && !intrinsic.useClamp)
      return op->emitOpError("intrinsic doesn't support clamp");
    if (intrinsic.useClamp && op.getClampAttr())
      attrs.push_back({"clamp", op.getClampAttr()});

    const bool isGFX1250orHigher =
        chipset.majorVersion == 12 && chipset.minorVersion >= 5;

    Value c = adaptor.getDestC();
    VectorType rawOutType = outType;
    if (!isGFX1250orHigher) {
      rawOutType = cast<VectorType>(c.getType());

    Value sparseIdx = LLVM::BitcastOp::create(
        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());

    OperationState loweredOp(loc, intrinsic.name);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands({a, b, c, sparseIdx});
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
    rewriter.replaceOp(op, maybeCastBack->getResults());
  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestD().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

      return op->emitOpError("WMMA scale only supported on gfx1250+");

    int64_t m = op.getM();
    int64_t n = op.getN();
    int64_t k = op.getK();

    if (!aFmtCode || !bFmtCode)
      return op.emitOpError("unsupported element types for scaled_wmma");

    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());

    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
      return op.emitOpError("scaleA and scaleB must have equal vector length");

    Type scaleAElemType = scaleAVecType.getElementType();
    Type scaleBElemType = scaleBVecType.getElementType();

    if (!scaleAFmt || !scaleBFmt)
      return op.emitOpError("unsupported scale element types");

    bool isScale16 = (scaleAVecType.getNumElements() == 8);
    std::optional<StringRef> intrinsicName =
      return op.emitOpError("unsupported scaled_wmma dimensions: ")
             << m << "x" << n << "x" << k;

    SmallVector<NamedAttribute, 8> attrs;

    bool is32x16 = (m == 32 && n == 16 && k == 128);

    attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
    attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));

      attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));

        "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
        "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));

    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));

    OperationState loweredOp(loc, *intrinsicName);
    loweredOp.addTypes(outType);
    loweredOp.addOperands(
        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
    loweredOp.addAttributes(attrs);

    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());
struct TransposeLoadOpLowering
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")

    auto resultType = cast<VectorType>(op.getResult().getType());
                                 (adaptor.getSrcIndices()));

    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =

    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);

    switch (elementTypeSize) {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      assert(numElements == 8);
      auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      assert(numElements == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,

      return op.emitOpError("Unsupported element size for transpose load");
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    Type transferType = op.getTransferType();
    int loadWidth = [&]() -> int {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return (transferVectorType.getNumElements() *
                transferVectorType.getElementTypeBitWidth()) /

    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");

    if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

                              (adaptor.getSrcIndices()));
                              (adaptor.getDstIndices()));

    if (op.getAsync()) {
      rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
          op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
          rewriter.getI32IntegerAttr(0),
      rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
          op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
          rewriter.getI32IntegerAttr(0),
struct GlobalLoadAsyncToLDSOpLowering
  GlobalLoadAsyncToLDSOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),

  matchAndRewrite(GlobalLoadAsyncToLDSOp op,
                  GlobalLoadAsyncToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
      return op.emitOpError(
          "global_load_async_to_lds is only supported on gfx1250+");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    Type transferType = op.getTransferType();
        isa<VectorType>(transferType)
            ? cast<VectorType>(transferType).getNumElements() *
                  cast<VectorType>(transferType).getElementTypeBitWidth()

                              adaptor.getSrcIndices());
                              adaptor.getDstIndices());

    Value mask = adaptor.getMask();
    int64_t nullptrVal =
        llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
        LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.getType(), nullInt);
    dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
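    // When the mask is false the LDS destination is replaced by the
    // local-address-space null pointer (presumably making the transfer a
    // no-op for that case).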
    auto offset = rewriter.getI32IntegerAttr(0);
    auto aux = rewriter.getI32IntegerAttr(0);

    switch (transferBits) {
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
      rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
      return op.emitOpError("unsupported transfer width");
struct ExtPackedFp8OpLowering final
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),

  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct ScaledExtPackedMatrixOpLowering final
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),

  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedTrunc2xFp8OpLowering final
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),

  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedStochRoundFp8OpLowering final
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),

  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct ScaledExtPackedOpLowering final
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),

  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedScaledTruncOpLowering final
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),

  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 = getTypeConverter()->convertType(
      VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8 so the bitcast to i32 below is well-formed.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (resultVecType) {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType))
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        wordSel);
    else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType))
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        wordSel);
  } else {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType))
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      wordSel);
    else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType))
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      wordSel);
  }
  return success();
}
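// Packs the scale-operand selection into the scaleSel immediate consumed by
// the ROCDL cvt_scale_pk intrinsics: bit 0 flags 16-element blocks, the next
// bit(s) choose which byte of the 32-bit scale value is used, and the top bit
// selects the wave half (first scale lane 0 or 16). As implemented below,
// 8-bit sources use the 4-bit encoding and 4-/6-bit sources the 3-bit one.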
static int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
                           int32_t scaleWaveHalf, int32_t firstScaleByte) {
  assert(llvm::is_contained({16, 32}, blockSize));
  assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));

  const bool isFp8 = bitWidth == 8;
  const bool isBlock16 = blockSize == 16;

  if (!isFp8) {
    // 4- and 6-bit sources: 3-bit scale_sel.
    int32_t bit0 = isBlock16;
    assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
    int32_t bit1 = (firstScaleByte == 2) << 1;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    int32_t bit2 = scaleWaveHalf << 2;
    return bit2 | bit1 | bit0;
  }

  // 8-bit sources: 4-bit scale_sel.
  int32_t bit0 = isBlock16;
  assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
  int32_t bits2and1 = firstScaleByte << 1;
  assert(llvm::is_contained({0, 1}, scaleWaveHalf));
  int32_t bit3 = scaleWaveHalf << 3;
  int32_t bits = bit3 | bits2and1 | bit0;
  // Reject encodings the hardware does not define.
  assert(!llvm::is_contained(
      {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
  return bits;
}
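// Maps a (source element type, destination element type) pair to the ROCDL
// packed scaled-conversion intrinsic name, or std::nullopt when the
// destination is not one of f16/bf16/f32. 4- and 8-bit sources convert 8
// elements per intrinsic (Pk8); 6-bit sources convert 16 (Pk16).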
static std::optional<StringRef>
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  if (isa<fp4>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
    return std::nullopt;
  }
  llvm_unreachable(
      "invalid combination of element types for packed conversion");
}
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
    ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        op,
        "Scaled fp packed conversion instructions are not available on target "
        "architecture and their emulation is not implemented");

  int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
  int32_t firstScaleByte = op.getFirstScaleByte();
  int32_t blockSize = op.getBlockSize();
  auto sourceType = cast<VectorType>(op.getSource().getType());
  auto srcElemType = cast<FloatType>(sourceType.getElementType());
  unsigned bitWidth = srcElemType.getWidth();
  auto targetType = cast<VectorType>(op.getResult().getType());
  auto destElemType = cast<FloatType>(targetType.getElementType());

  IntegerType i32 = rewriter.getI32Type();
  Value source = adaptor.getSource();
  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
  // The intrinsics consume the packed source as one, two, or three dwords
  // depending on the source element width.
  Type packedType = nullptr;
  if (isa<fp4>(srcElemType)) {
    packedType = i32;
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp8, bf8>(srcElemType)) {
    packedType = VectorType::get(2, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp6, bf6>(srcElemType)) {
    packedType = VectorType::get(3, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else {
    llvm_unreachable("invalid element type for packed scaled ext");
  }

  if (!packedType || !llvmResultType)
    return rewriter.notifyMatchFailure(op, "type conversion failed");

  std::optional<StringRef> maybeIntrinsic =
      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
  if (!maybeIntrinsic.has_value())
    return op.emitOpError(
        "no intrinsic matching packed scaled conversion on the given chipset");

  int32_t scaleSel =
      getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
  Value castedScale =
      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
  Value castedSource =
      LLVM::BitcastOp::create(rewriter, loc, packedType, source);

  OperationState loweredOp(loc, *maybeIntrinsic);
  loweredOp.addTypes({llvmResultType});
  loweredOp.addOperands({castedSource, castedScale});

  SmallVector<NamedAttribute, 1> attrs;
  attrs.push_back(
      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
  loweredOp.addAttributes(attrs);
  Operation *lowered = rewriter.create(loweredOp);
  rewriter.replaceOp(op, lowered);
  return success();
}
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  VectorType packedVecType;
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
  } else {
    llvm_unreachable("invalid element type for scaled ext");
  }

  // Extend to the packed vector width so the bitcast to i32 is well-formed.
  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else
    return failure();

  return success();
}
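// Lowers PackedScaledTruncOp to the cvt_scale_f32_pk truncation intrinsics.
// f32 sources are split into two scalar operands (sourceA/sourceB); f16 and
// bf16 sources are passed as the packed vector directly.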
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = cast<VectorType>(resultType).getElementType();
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  Value result;
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(op.getResult());

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(op.getResult());

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
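// Lowers amdgpu DPP ops to ROCDL::DPPUpdateOp. Operands narrower than 32 bits
// are widened into a 32-bit value first, and the permutation kind plus its
// argument are folded into the single dpp_ctrl immediate.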
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // Sub-32-bit operands are widened to 32 bits before the DPP update.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand = LLVM::InsertElementOp::create(
            rewriter, loc, undefVec, operand,
            createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // DPP control encodings as defined by the ISA.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {
    case DPPPerm::quad_perm: {
      auto quadPermAttr = cast<ArrayAttr>(*permArgument);
      int32_t i = 0;
      for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
        uint32_t num = elem.getInt();
        DppCtrl |= num << (i * 2);
        ++i;
      }
      break;
    }
    case DPPPerm::row_shl: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      break;
    }
    case DPPPerm::row_shr: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      break;
    }
    case DPPPerm::row_ror: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      break;
    }
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
                                   rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    rewriter.replaceOp(DppOp, result);
    return success();
  }
};
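// Lowers amdgpu.swizzle_bitmode to ds_swizzle in bit-mode: the and/or/xor
// masks occupy bits 0-4, 5-9, and 10-14 of the offset immediate.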
struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();

    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
    Value maskValue = createI32Constant(rewriter, loc, mask);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }

    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
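// Lowers amdgpu.permlane_swap to rocdl.permlane{16,32}.swap. Wide values are
// decomposed into i32 pieces, each piece is swapped, and the results are
// recomposed to the original type.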
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    SmallVector<Value> decomposed;
    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to decompose value to i32");

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});
      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        llvm_unreachable("unsupported row length");

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      // Both halves of the swap are returned; keep the one that differs from
      // the input value.
      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
                                           LLVM::ICmpPredicate::eq, vdst0, v);
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
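// Layout of the 64-bit DS barrier state word used by the lowerings below:
// bits 0-28 hold the pending count, bit 29 the phase, and bits 32 and up the
// initial count.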
constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
constexpr int32_t kDsBarrierInitCountPos = 32;
constexpr int32_t kDsBarrierPendingCountMask =
    (1 << kDsBarrierPendingCountBitWidth) - 1;

struct DsBarrierInitOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
  DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // Seed both the pending count and the init count with participants - 1.
    Value initCount =
        LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
                            createI32Constant(rewriter, loc, 1));
    Value countMask =
        createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
    Value maskedCount32 =
        LLVM::AndOp::create(rewriter, loc, initCount, countMask);
    Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);

    Value initCountShifted = LLVM::ShlOp::create(
        rewriter, loc, maskedCount,
        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
    Value barrierState =
        LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);

    LLVM::StoreOp::create(rewriter, loc, barrierState, ptr, 8, false,
                          false, LLVM::AtomicOrdering::release);

    rewriter.eraseOp(op);
    return success();
  }
};
struct DsBarrierPollStateOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
  DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
                               Chipset chipset)
      : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        op, i64, ptr, 8, false,
        false, LLVM::AtomicOrdering::acquire);
    return success();
  }
};

struct DsAsyncBarrierArriveOpLowering
    : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
  DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
                                 Chipset chipset)
      : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
        op, ptr, nullptr, nullptr);
    return success();
  }
};
struct DsBarrierArriveOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
  DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
  }
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("only supported on gfx1250+");

    Location loc = op.getLoc();
    Type i64 = rewriter.getI64Type();

    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
                                     adaptor.getBase(), adaptor.getIndices());

    rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
        op, i64, ptr, adaptor.getCount(), nullptr);
    return success();
  }
};
struct DsBarrierStatePhaseOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value state = adaptor.getState();

    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
    Value phase = LLVM::LShrOp::create(
        rewriter, loc, noInitCount,
        createI32Constant(rewriter, loc, kDsBarrierPhasePos));

    rewriter.replaceOp(op, phase);
    return success();
  }
};

struct DsBarrierStatePendingCountOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value state = adaptor.getState();

    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
    Value pendingCount = LLVM::AndOp::create(
        rewriter, loc, noInitCount,
        createI32Constant(rewriter, loc,
                          static_cast<uint32_t>(kDsBarrierPendingCountMask)));

    rewriter.replaceOp(op, pendingCount);
    return success();
  }
};

struct DsBarrierStateInitCountOpLowering
    : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value state = adaptor.getState();

    Value initCountI64 = LLVM::LShrOp::create(
        rewriter, loc, state,
        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
    Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);

    rewriter.replaceOp(op, initCount);
    return success();
  }
};

struct DsBarrierStatePhaseParityLowering
    : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i1 = rewriter.getI1Type();
    Value state = adaptor.getState();

    Value noInitCount =
        LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
    Value phase = LLVM::LShrOp::create(
        rewriter, loc, noInitCount,
        createI32Constant(rewriter, loc, kDsBarrierPhasePos));
    Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);

    rewriter.replaceOp(op, parity);
    return success();
  }
};
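// OR a field into a 32-bit descriptor word: shift `value` to its bit offset
// and combine it with `accumulator` using a disjoint OR.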
static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
                              Value accumulator, Value value, int64_t shift) {
  // Note (assumption based on the callers below, which pass bit positions in
  // the full descriptor): only the offset within the current 32-bit word
  // matters for the shift itself.
  shift %= 32;
  Value shiftAmount = createI32Constant(rewriter, loc, shift);
  value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
  // The packed fields never overlap, so the OR can be marked disjoint.
  constexpr bool isDisjoint = true;
  return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
}
3415template <
typename BaseOp>
3416struct AMDGPUMakeDmaBaseLowering :
public ConvertOpToLLVMPattern<BaseOp> {
3417 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3420 AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3421 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3425 matchAndRewrite(BaseOp op, Adaptor adaptor,
3426 ConversionPatternRewriter &rewriter)
const override {
3428 return op->emitOpError(
"make_dma_base is only supported on gfx1250");
3430 Location loc = op.getLoc();
3432 constexpr int32_t constlen = 4;
3433 Value consts[constlen];
3434 for (int64_t i = 0; i < constlen; ++i)
3437 constexpr int32_t sgprslen = constlen;
3438 Value sgprs[sgprslen];
3439 for (int64_t i = 0; i < sgprslen; ++i) {
3440 sgprs[i] = consts[0];
3443 sgprs[0] = consts[1];
3445 if constexpr (BaseOp::isGather()) {
3446 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3448 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3449 Type indexType = type.getIndexType();
3451 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3452 "expected index_size to be 16 or 32");
3453 unsigned idx = (indexSize / 16) - 1;
3456 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
3459 ValueRange ldsIndices = adaptor.getLdsIndices();
3460 Value lds = adaptor.getLds();
3461 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3464 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3466 ValueRange globalIndices = adaptor.getGlobalIndices();
3467 Value global = adaptor.getGlobal();
3468 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3471 rewriter, loc, globalMemRefType, global, globalIndices);
3473 Type i32 = rewriter.getI32Type();
3474 Type i64 = rewriter.getI64Type();
3476 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3477 Value castForGlobalAddr =
3478 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3480 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3482 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3485 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3488 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3490 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
3492 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3493 assert(v4i32 &&
"expected type conversion to succeed");
3494 Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3496 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3498 LLVM::InsertElementOp::create(rewriter, loc,
result, sgpr, constant);
3500 rewriter.replaceOp(op,
result);
3505template <
typename DescriptorOp>
3506struct AMDGPULowerDescriptor :
public ConvertOpToLLVMPattern<DescriptorOp> {
3507 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
3510 AMDGPULowerDescriptor(
const LLVMTypeConverter &converter, Chipset chipset)
3511 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
3514 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
3516 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3517 ConversionPatternRewriter &rewriter, Location loc,
3518 Value sgpr0)
const {
3519 Value mask = op.getWorkgroupMask();
3523 Type i16 = rewriter.getI16Type();
3524 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3525 Type i32 = rewriter.getI32Type();
3526 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3527 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
3530 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3531 ConversionPatternRewriter &rewriter, Location loc,
3532 Value sgpr0, ArrayRef<Value> consts)
const {
3533 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3534 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3535 "expected type width to be 8, 16, 32, or 64.");
3536 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3537 Value size = consts[idx];
3538 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3541 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3542 ConversionPatternRewriter &rewriter, Location loc,
3543 Value sgpr0, ArrayRef<Value> consts)
const {
3544 if (!adaptor.getAtomicBarrierAddress())
3547 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
3550 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3551 ConversionPatternRewriter &rewriter, Location loc,
3552 Value sgpr0, ArrayRef<Value> consts)
const {
3553 if (!adaptor.getGlobalIncrement())
3558 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
3561 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3562 ConversionPatternRewriter &rewriter, Location loc,
3563 Value sgpr0, ArrayRef<Value> consts)
const {
3564 if (!op.getPadAmount())
3567 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
3570 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3571 ConversionPatternRewriter &rewriter, Location loc,
3572 Value sgpr0, ArrayRef<Value> consts)
const {
3573 if (!op.getWorkgroupMask())
3576 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
3579 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3580 ConversionPatternRewriter &rewriter, Location loc,
3581 Value sgpr0, ArrayRef<Value> consts)
const {
3582 if (!op.getPadAmount())
3591 IntegerType i32 = rewriter.getI32Type();
3592 Value padInterval = adaptor.getPadInterval();
3593 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3594 padInterval,
false);
3595 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3597 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
3600 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3601 ConversionPatternRewriter &rewriter, Location loc,
3602 Value sgpr0, ArrayRef<Value> consts)
const {
3603 if (!op.getPadAmount())
3612 Value padAmount = adaptor.getPadAmount();
3613 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3615 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
3618 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3619 ConversionPatternRewriter &rewriter,
3620 Location loc, Value sgpr1,
3621 ArrayRef<Value> consts)
const {
3622 if (!adaptor.getAtomicBarrierAddress())
3625 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3626 auto barrierAddressTy =
3627 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3628 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3630 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3631 atomicBarrierIndices);
3632 IntegerType i32 = rewriter.getI32Type();
3638 atomicBarrierAddress =
3639 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3640 atomicBarrierAddress =
3641 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3643 atomicBarrierAddress =
3644 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3645 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
3648 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3649 ConversionPatternRewriter &rewriter,
3650 Location loc, Value sgpr1, Value sgpr2,
3651 ArrayRef<Value> consts, uint64_t dimX,
3652 uint32_t offset)
const {
3653 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3654 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3655 SmallVector<OpFoldResult> mixedGlobalSizes =
3657 if (mixedGlobalSizes.size() <= dimX)
3658 return {sgpr1, sgpr2};
3660 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3667 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3671 IntegerType i32 = rewriter.getI32Type();
3672 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3673 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3676 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3679 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3680 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3681 return {sgpr1, sgpr2};
3684 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3685 ConversionPatternRewriter &rewriter,
3686 Location loc, Value sgpr1, Value sgpr2,
3687 ArrayRef<Value> consts)
const {
3688 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3692 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3693 ConversionPatternRewriter &rewriter,
3694 Location loc, Value sgpr2, Value sgpr3,
3695 ArrayRef<Value> consts)
const {
3696 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3700 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3701 ConversionPatternRewriter &rewriter, Location loc,
3702 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
3703 int64_t offset)
const {
3704 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3705 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3706 SmallVector<OpFoldResult> mixedSharedSizes =
3708 if (mixedSharedSizes.size() <= dimX)
3711 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
3720 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3724 IntegerType i32 = rewriter.getI32Type();
3725 tileDimX = cast<Value>(tileDimXOpFoldResult);
3726 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3729 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
3732 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3733 ConversionPatternRewriter &rewriter, Location loc,
3734 Value sgpr3, ArrayRef<Value> consts)
const {
3735 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3738 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3739 ConversionPatternRewriter &rewriter, Location loc,
3740 Value sgpr4, ArrayRef<Value> consts)
const {
3741 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3744 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3745 ConversionPatternRewriter &rewriter, Location loc,
3746 Value sgpr4, ArrayRef<Value> consts)
const {
3747 auto type = cast<VectorType>(op.getIndices().getType());
3748 ArrayRef<int64_t> shape = type.getShape();
3749 assert(shape.size() == 1 &&
"expected shape to be of rank 1.");
3750 unsigned length = shape.back();
3751 assert(0 < length && length <= 16 &&
"expected length to be at most 16.");
3753 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3756 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3757 ConversionPatternRewriter &rewriter,
3758 Location loc, Value sgpr4,
3759 ArrayRef<Value> consts)
const {
3760 if constexpr (DescriptorOp::isGather())
3761 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3762 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3765 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3766 ConversionPatternRewriter &rewriter, Location loc,
3767 Value sgpr4, ArrayRef<Value> consts)
const {
3769 if constexpr (DescriptorOp::isGather())
3771 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3774 std::pair<Value, Value>
3775 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3776 ConversionPatternRewriter &rewriter, Location loc,
3777 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
3778 size_t dimX, int64_t offset)
const {
3779 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
3780 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
3781 SmallVector<OpFoldResult> mixedGlobalStrides =
3782 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
3784 if (mixedGlobalStrides.size() <= (dimX + 1))
3785 return {sgprY, sgprZ};
3787 OpFoldResult tensorDimXStrideOpFoldResult =
3788 *(mixedGlobalStrides.rbegin() + dimX + 1);
3793 Value tensorDimXStride;
3794 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3798 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
3800 constexpr int64_t first48bits = (1ll << 48) - 1;
3803 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3804 IntegerType i32 = rewriter.getI32Type();
3805 Value tensorDimXStrideLow =
3806 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3807 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
3809 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3811 Value tensorDimXStrideHigh =
3812 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3813 tensorDimXStrideHigh =
3814 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3815 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3817 return {sgprY, sgprZ};
3820 std::pair<Value, Value>
3821 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3822 ConversionPatternRewriter &rewriter, Location loc,
3823 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3824 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3828 std::pair<Value, Value>
3829 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3830 ConversionPatternRewriter &rewriter, Location loc,
3831 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3833 if constexpr (DescriptorOp::isGather())
3834 return {sgpr5, sgpr6};
3835 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3839 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3840 ConversionPatternRewriter &rewriter, Location loc,
3841 ArrayRef<Value> consts)
const {
3843 for (int64_t i = 0; i < 8; ++i) {
3844 sgprs[i] = consts[0];
3847 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3848 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3849 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3850 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3851 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3852 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3853 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3854 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3857 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3858 std::tie(sgprs[1], sgprs[2]) =
3859 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3860 std::tie(sgprs[2], sgprs[3]) =
3861 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3863 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3865 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3866 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3867 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3868 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3869 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3870 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3872 IntegerType i32 = rewriter.getI32Type();
3873 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3874 assert(v8i32 &&
"expected type conversion to succeed");
3875 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3877 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3879 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3885 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3886 ConversionPatternRewriter &rewriter, Location loc,
3887 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3888 int64_t offset)
const {
3889 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3890 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3891 SmallVector<OpFoldResult> mixedGlobalSizes =
3893 if (mixedGlobalSizes.size() <=
static_cast<unsigned long>(dimX))
3896 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3898 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3902 IntegerType i32 = rewriter.getI32Type();
3903 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3904 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3907 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3910 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3911 ConversionPatternRewriter &rewriter, Location loc,
3912 Value sgpr0, ArrayRef<Value> consts)
const {
3913 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3916 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3917 Location loc, Value accumulator,
3918 Value value, int64_t shift)
const {
3920 IntegerType i32 = rewriter.getI32Type();
3921 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3922 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3925 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3926 ConversionPatternRewriter &rewriter, Location loc,
3927 Value sgpr1, ArrayRef<Value> consts,
3928 int64_t offset)
const {
3929 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3930 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3933 std::pair<Value, Value>
3934 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3935 ConversionPatternRewriter &rewriter, Location loc,
3936 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3937 int64_t offset)
const {
3938 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3939 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3940 globalAddrIncrement, offset);
3942 globalAddrIncrement =
3943 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3944 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3945 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3946 globalAddrIncrement, offset + 32);
3948 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3949 return {sgpr2, sgpr3};
3952 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3953 ConversionPatternRewriter &rewriter,
3954 Location loc, Value sgpr1,
3955 ArrayRef<Value> consts)
const {
3956 Value ldsIncrement = op.getLdsIncrement();
3957 constexpr int64_t dim = 3;
3958 constexpr int64_t offset = 32;
3960 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3962 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3966 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3967 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3968 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts)
const {
3969 Value globalIncrement = op.getGlobalIncrement();
3970 constexpr int32_t dim = 2;
3971 constexpr int32_t offset = 64;
3972 if (!globalIncrement)
3973 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3974 consts, dim, offset);
3975 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3979 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3980 ConversionPatternRewriter &rewriter, Location loc,
3981 Value sgpr3, ArrayRef<Value> consts,
3982 int32_t offset)
const {
3983 Value iterationCount = adaptor.getIterationCount();
3984 IntegerType i32 = rewriter.getI32Type();
3991 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3993 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3994 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3997 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3998 ConversionPatternRewriter &rewriter,
3999 Location loc, Value sgpr3,
4000 ArrayRef<Value> consts)
const {
4001 Value iterateCount = op.getIterationCount();
4002 constexpr int32_t dim = 2;
4003 constexpr int32_t offset = 112;
4005 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
4008 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
4011 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
4012 ConversionPatternRewriter &rewriter, Location loc,
4013 ArrayRef<Value> consts)
const {
4014 if constexpr (DescriptorOp::isGather())
4015 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
4016 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
4019 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
4020 ConversionPatternRewriter &rewriter, Location loc,
4021 ArrayRef<Value> consts)
const {
4022 IntegerType i32 = rewriter.getI32Type();
4023 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4024 assert(v4i32 &&
"expected type conversion to succeed.");
4026 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4027 if (onlyNeedsTwoDescriptors)
4028 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4030 constexpr int64_t sgprlen = 4;
4031 Value sgprs[sgprlen];
4032 for (
int i = 0; i < sgprlen; ++i)
4033 sgprs[i] = consts[0];
4035 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
4036 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
4038 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
4039 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4041 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
4043 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4044 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
4046 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
4051 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
4052 ConversionPatternRewriter &rewriter, Location loc,
4053 ArrayRef<Value> consts,
bool firstHalf)
const {
4054 IntegerType i32 = rewriter.getI32Type();
4055 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4056 assert(v4i32 &&
"expected type conversion to succeed.");
4058 Value
indices = adaptor.getIndices();
4059 auto vectorType = cast<VectorType>(
indices.getType());
4060 unsigned length = vectorType.getShape().back();
4061 Type elementType = vectorType.getElementType();
4062 unsigned maxLength = elementType == i32 ? 4 : 8;
4063 int32_t offset = firstHalf ? 0 : maxLength;
4064 unsigned discountedLength =
4065 std::max(
static_cast<int32_t
>(length - offset), 0);
4067 unsigned targetSize = std::min(maxLength, discountedLength);
4069 SmallVector<Value> indicesVector;
4070 for (
unsigned i = offset; i < targetSize + offset; ++i) {
4072 if (i < consts.size())
4076 Value elem = LLVM::ExtractElementOp::create(rewriter, loc,
indices, idx);
4077 indicesVector.push_back(elem);
4080 SmallVector<Value> indicesI32Vector;
4081 if (elementType == i32) {
4082 indicesI32Vector = indicesVector;
4084 for (
unsigned i = 0; i < targetSize; ++i) {
4085 Value index = indicesVector[i];
4086 indicesI32Vector.push_back(
4087 LLVM::ZExtOp::create(rewriter, loc, i32, index));
4089 if ((targetSize % 2) != 0)
4091 indicesI32Vector.push_back(consts[0]);
4094 SmallVector<Value> indicesToInsert;
4095 if (elementType == i32) {
4096 indicesToInsert = indicesI32Vector;
4098 unsigned size = indicesI32Vector.size() / 2;
4099 for (
unsigned i = 0; i < size; ++i) {
4100 Value first = indicesI32Vector[2 * i];
4101 Value second = indicesI32Vector[2 * i + 1];
4102 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
4103 indicesToInsert.push_back(joined);
4107 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4108 for (
auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
4110 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
4115 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
4116 ConversionPatternRewriter &rewriter, Location loc,
4117 ArrayRef<Value> consts)
const {
4118 return getGatherIndices(op, adaptor, rewriter, loc, consts,
true);
4121 std::pair<Value, Value>
4122 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
4123 ConversionPatternRewriter &rewriter, Location loc,
4124 Value sgpr0, Value sgpr1, ArrayRef<Value> consts)
const {
4125 constexpr int32_t dim = 3;
4126 constexpr int32_t offset = 0;
4127 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
4131 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
4132 ConversionPatternRewriter &rewriter,
4133 Location loc, Value sgpr1, Value sgpr2,
4134 ArrayRef<Value> consts)
const {
4135 constexpr int32_t dim = 4;
4136 constexpr int32_t offset = 48;
4137 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
4141 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
4142 ConversionPatternRewriter &rewriter, Location loc,
4143 Value sgpr2, ArrayRef<Value> consts)
const {
4144 constexpr int32_t dim = 4;
4145 constexpr int32_t offset = 80;
4146 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
4149 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
4150 ConversionPatternRewriter &rewriter, Location loc,
4151 ArrayRef<Value> consts)
const {
4152 if constexpr (DescriptorOp::isGather())
4153 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
4154 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
  Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    // With no LDS increment and rank <= 2, only the first two descriptor
    // groups carry information; this group stays all-zero.
    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    constexpr int32_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
        op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);

    Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup3 = LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr,
                                              constant);
    return dgroup3;
  }
  Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts) const {
    return getGatherIndices(op, adaptor, rewriter, loc, consts,
                            false);
  }
  LogicalResult
  matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError(
          "make_dma_descriptor is only supported on gfx1250");

    Location loc = op.getLoc();

    // i32 constants 0..7, reused as insertelement lane indices and as zero
    // filler values.
    SmallVector<Value> consts;
    for (int64_t i = 0; i < 8; ++i)
      consts.push_back(createI32Constant(rewriter, loc, i));

    Value dgroup0 = this->getDGroup0(adaptor);
    Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
    Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
    Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
    SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
    rewriter.replaceOpWithMultiple(op, {results});
    return success();
  }
};
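
// Lowers the AMDGPU tensor load/store (LDS) ops to their ROCDL counterparts.
// The four descriptor groups produced by the descriptor lowering above are
// forwarded unchanged; a fifth group (dgroup4), which the AMDGPU-level op does
// not appear to carry, is passed as an all-zero vector<8xi32>.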
template <typename SourceOp, typename TargetOp>
struct AMDGPUTensorLoadStoreOpLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}

  Chipset chipset;
  using Adaptor = typename ConvertOpToLLVMPattern<SourceOp>::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(SourceOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError("is only supported on gfx1250");

    // The converted descriptor operand expands 1:N into the four descriptor
    // groups (see the TDMDescriptorType conversion registered below).
    ValueRange desc = adaptor.getDesc();

    auto v8i32 = VectorType::get(8, rewriter.getI32Type());
    Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
    rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
                                          desc[3], dgroup4, 0);
    return success();
  }
};
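
// Lowers GlobalPrefetchOp to ROCDL::GlobalPrefetchOp. The temporal hint, cache
// scope, and speculative flag are folded into a single immediate via
// getGlobalPrefetchLLVMEncoding. For speculative prefetches the address GEP
// drops the inbounds/nuw flags, presumably because a speculative prefetch may
// compute an address that is out of bounds.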
struct GlobalPrefetchOpLowering
    : public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
  GlobalPrefetchOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx1250)
      return op->emitOpError("is only supported on gfx1250+");

    const bool isSpeculative = op.getSpeculative();
    int32_t immArgValue = getGlobalPrefetchLLVMEncoding(
        op.getTemporalHint(), op.getCacheScope(), isSpeculative);
    IntegerAttr immArgAttr = rewriter.getI32IntegerAttr(immArgValue);

    Value memRef = adaptor.getSrc();
    ValueRange indices = adaptor.getIndices();
    MemRefDescriptor descriptor(memRef);
    MemRefType memRefType = op.getSrc().getType();
    Location loc = op->getLoc();
    auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
                                       : LLVM::GEPNoWrapFlags::inbounds |
                                             LLVM::GEPNoWrapFlags::nuw;
    Value prefetchPtr = getStridedElementPtr(
        rewriter, loc, memRefType, descriptor, indices, inboundsFlags);

    rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
        op, prefetchPtr, immArgAttr, mlir::ArrayAttr{}, mlir::ArrayAttr{},
        mlir::ArrayAttr{});
    return success();
  }
};
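
// Pass wrapper: parses the requested chipset, populates the conversion
// patterns and type conversions defined in this file, and runs a partial
// conversion that outlaws the AMDGPU dialect while keeping LLVM and ROCDL
// legal.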
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
    ConversionTarget target(*ctx);
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
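
// The pass above is normally driven from the command line; a typical
// invocation (flag spelling assumed from the registered pass name) is:
//   mlir-opt --convert-amdgpu-to-rocdl='chipset=gfx1250' input.mlir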
void mlir::amdgpu::populateCommonGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(
      typeConverter, [](gpu::AddressSpace space) {
        switch (space) {
        case gpu::AddressSpace::Global:
          return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
        case gpu::AddressSpace::Workgroup:
          return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
        case gpu::AddressSpace::Private:
          return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
        case gpu::AddressSpace::Constant:
          return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
        }
        llvm_unreachable("unknown address space enum value");
      });
}
void mlir::amdgpu::populateAMDGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });
  typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
    return IntegerType::get(type.getContext(), 64);
  });
  typeConverter.addConversion([&](TDMBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  // A TDM descriptor expands 1:N into its four descriptor groups.
  typeConverter.addConversion(
      [&](TDMDescriptorType type, SmallVectorImpl<Type> &result)
          -> std::optional<LogicalResult> {
        Type i32 = IntegerType::get(type.getContext(), 32);
        Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
        Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
        llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
        return success();
      });

  // When converted values are needed but only the original packed descriptor
  // is available, bridge the gap with an unrealized_conversion_cast.
  auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
                              ValueRange inputs,
                              Location loc) -> SmallVector<Value> {
    if (inputs.size() != 1)
      return {};
    if (!isa<TDMDescriptorType>(inputs[0].getType()))
      return {};
    auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
    return cast.getResults();
  };
  typeConverter.addTargetMaterialization(addUnrealizedCast);
}
void mlir::populateAMDGPUToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
           SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
           SparseWMMAOpLowering, DotOpLowering, ExtPackedFp8OpLowering,
           ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
           PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
           GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
           AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
           AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
           AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
                                           ROCDL::TensorLoadToLDSOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
                                           ROCDL::TensorStoreFromLDSOp>,
           DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
           DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
           GlobalPrefetchOpLowering>(converter, chipset);
  patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
               DsBarrierStatePendingCountOpLowering,
               DsBarrierStateInitCountOpLowering,
               DsBarrierStatePhaseParityLowering>(converter);
}
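
// Sketch of direct, programmatic reuse of the patterns above (the pass earlier
// in this file shows the full recipe; names other than the populate and parse
// functions are the caller's own):
//
//   FailureOr<amdgpu::Chipset> maybeChipset =
//       amdgpu::Chipset::parse("gfx1250");
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   if (succeeded(maybeChipset))
//     populateAMDGPUToROCDLConversionPatterns(converter, patterns,
//                                             *maybeChipset);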