30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/AMDGPUAddrSpace.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/ErrorHandling.h"
39#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
40#include "mlir/Conversion/Passes.h.inc"
58 return chipset >=
Chipset(9, 0, 6);
100 if (chipset ==
Chipset(9, 5, 0))
110 IntegerType i32 = rewriter.getI32Type();
112 auto valTy = cast<IntegerType>(val.
getType());
115 return valTy.getWidth() > 32
116 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
117 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
122 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
128 IntegerType i64 = rewriter.getI64Type();
130 auto valTy = cast<IntegerType>(val.
getType());
133 return valTy.getWidth() > 64
134 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
135 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
140 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
147 IntegerType i32 = rewriter.getI32Type();
149 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
152 ShapedType::isDynamic(stride)
154 memRefDescriptor.
stride(rewriter, loc, i))
155 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
156 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
168 MemRefType memrefType,
172 if (chipset >=
kGfx1250 && !boundsCheck) {
173 constexpr int64_t first45bits = (1ll << 45) - 1;
176 if (memrefType.hasStaticShape() &&
177 !llvm::any_of(strides, ShapedType::isDynamic)) {
178 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
180 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
181 size = std::max(
shape[i] * strides[i], size);
182 size = size * elementByteWidth;
186 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
187 Value size = memrefDescriptor.
size(rewriter, loc, i);
188 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
189 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
191 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
196 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
202 Value cacheSwizzleStride =
nullptr,
203 unsigned addressSpace = 8) {
207 Type i16 = rewriter.getI16Type();
210 Value cacheStrideZext =
211 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
212 Value swizzleBit = LLVM::ConstantOp::create(
213 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
214 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
217 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
218 rewriter.getI16IntegerAttr(0));
247 flags |= (7 << 12) | (4 << 15);
250 uint32_t oob = boundsCheck ? 3 : 2;
251 flags |= (oob << 28);
256 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
257 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
258 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
263struct FatRawBufferCastLowering
265 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
266 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
272 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter)
const override {
274 Location loc = op.getLoc();
275 Value memRef = adaptor.getSource();
276 Value unconvertedMemref = op.getSource();
277 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
278 MemRefDescriptor descriptor(memRef);
280 DataLayout dataLayout = DataLayout::closest(op);
281 int64_t elementByteWidth =
284 int64_t unusedOffset = 0;
285 SmallVector<int64_t, 5> strideVals;
286 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
287 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
289 Value numRecords = adaptor.getValidBytes();
292 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
293 elementByteWidth, chipset, adaptor.getBoundsCheck());
296 adaptor.getResetOffset()
297 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
299 : descriptor.alignedPtr(rewriter, loc);
301 Value offset = adaptor.getResetOffset()
302 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
303 rewriter.getIndexAttr(0))
304 : descriptor.offset(rewriter, loc);
306 bool hasSizes = memrefType.getRank() > 0;
309 Value sizes = hasSizes
310 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
314 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
319 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
320 chipset, adaptor.getCacheSwizzleStride(), 7);
322 Value
result = MemRefDescriptor::poison(
324 getTypeConverter()->convertType(op.getResult().getType()));
326 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
327 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
329 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
332 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
334 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
337 rewriter.replaceOp(op,
result);
343template <
typename GpuOp,
typename Intrinsic>
345 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
346 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
349 static constexpr uint32_t maxVectorOpWidth = 128;
352 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
353 ConversionPatternRewriter &rewriter)
const override {
354 Location loc = gpuOp.getLoc();
355 Value memref = adaptor.getMemref();
356 Value unconvertedMemref = gpuOp.getMemref();
357 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
359 if (chipset.majorVersion < 9)
360 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
362 Value storeData = adaptor.getODSOperands(0)[0];
363 if (storeData == memref)
367 wantedDataType = storeData.
getType();
369 wantedDataType = gpuOp.getODSResults(0)[0].getType();
371 Value atomicCmpData = Value();
374 Value maybeCmpData = adaptor.getODSOperands(1)[0];
375 if (maybeCmpData != memref)
376 atomicCmpData = maybeCmpData;
379 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
381 Type i32 = rewriter.getI32Type();
384 DataLayout dataLayout = DataLayout::closest(gpuOp);
385 int64_t elementByteWidth =
394 Type llvmBufferValType = llvmWantedDataType;
396 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
397 llvmBufferValType = this->getTypeConverter()->convertType(
398 rewriter.getIntegerType(floatType.getWidth()));
400 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
401 uint32_t vecLen = dataVector.getNumElements();
404 uint32_t totalBits = elemBits * vecLen;
406 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
407 if (totalBits > maxVectorOpWidth)
408 return gpuOp.emitOpError(
409 "Total width of loads or stores must be no more than " +
410 Twine(maxVectorOpWidth) +
" bits, but we call for " +
412 " bits. This should've been caught in validation");
413 if (!usePackedFp16 && elemBits < 32) {
414 if (totalBits > 32) {
415 if (totalBits % 32 != 0)
416 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
417 "doesn't fit into words. Can't happen\n");
418 llvmBufferValType = this->typeConverter->convertType(
419 VectorType::get(totalBits / 32, i32));
421 llvmBufferValType = this->typeConverter->convertType(
422 rewriter.getIntegerType(totalBits));
426 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
429 if (vecType.getNumElements() == 1)
430 llvmBufferValType = vecType.getElementType();
433 SmallVector<Value, 6> args;
435 if (llvmBufferValType != llvmWantedDataType) {
436 Value castForStore = LLVM::BitcastOp::create(
437 rewriter, loc, llvmBufferValType, storeData);
438 args.push_back(castForStore);
440 args.push_back(storeData);
445 if (llvmBufferValType != llvmWantedDataType) {
446 Value castForCmp = LLVM::BitcastOp::create(
447 rewriter, loc, llvmBufferValType, atomicCmpData);
448 args.push_back(castForCmp);
450 args.push_back(atomicCmpData);
456 SmallVector<int64_t, 5> strides;
457 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
458 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
460 MemRefDescriptor memrefDescriptor(memref);
462 Value ptr = memrefDescriptor.bufferPtr(
463 rewriter, loc, *this->getTypeConverter(), memrefType);
465 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
466 elementByteWidth, chipset, adaptor.getBoundsCheck());
468 adaptor.getBoundsCheck(), chipset);
469 args.push_back(resource);
473 adaptor.getIndices(), strides);
474 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
475 indexOffset && *indexOffset > 0) {
477 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
481 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
482 args.push_back(voffset);
485 Value sgprOffset = adaptor.getSgprOffset();
488 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
489 args.push_back(sgprOffset);
496 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
498 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
499 ArrayRef<NamedAttribute>());
502 if (llvmBufferValType != llvmWantedDataType) {
503 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
508 rewriter.eraseOp(gpuOp);
525static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
526 unsigned expcnt,
unsigned lgkmcnt) {
528 vmcnt = std::min(15u, vmcnt);
529 expcnt = std::min(7u, expcnt);
530 lgkmcnt = std::min(15u, lgkmcnt);
531 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
534 vmcnt = std::min(63u, vmcnt);
535 expcnt = std::min(7u, expcnt);
536 lgkmcnt = std::min(15u, lgkmcnt);
537 unsigned lowBits = vmcnt & 0xF;
538 unsigned highBits = (vmcnt >> 4) << 14;
539 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
540 return lowBits | highBits | otherCnts;
543 vmcnt = std::min(63u, vmcnt);
544 expcnt = std::min(7u, expcnt);
545 lgkmcnt = std::min(63u, lgkmcnt);
546 unsigned lowBits = vmcnt & 0xF;
547 unsigned highBits = (vmcnt >> 4) << 14;
548 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
549 return lowBits | highBits | otherCnts;
552 vmcnt = std::min(63u, vmcnt);
553 expcnt = std::min(7u, expcnt);
554 lgkmcnt = std::min(63u, lgkmcnt);
555 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
560struct MemoryCounterWaitOpLowering
562 MemoryCounterWaitOpLowering(
const LLVMTypeConverter &converter,
564 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
570 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
571 ConversionPatternRewriter &rewriter)
const override {
572 if (chipset.majorVersion >= 12) {
573 Location loc = op.getLoc();
574 if (std::optional<int> ds = adaptor.getDs())
575 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
577 if (std::optional<int>
load = adaptor.getLoad())
578 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
580 if (std::optional<int> store = adaptor.getStore())
581 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
583 if (std::optional<int> exp = adaptor.getExp())
584 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
586 if (std::optional<int> tensor = adaptor.getTensor())
587 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
589 rewriter.eraseOp(op);
593 if (adaptor.getTensor())
594 return op.emitOpError(
"unsupported chipset");
596 auto getVal = [](Attribute attr) ->
unsigned {
598 return cast<IntegerAttr>(attr).getInt();
603 unsigned ds = getVal(adaptor.getDsAttr());
604 unsigned exp = getVal(adaptor.getExpAttr());
606 unsigned vmcnt = 1024;
607 Attribute
load = adaptor.getLoadAttr();
608 Attribute store = adaptor.getStoreAttr();
610 vmcnt = getVal(
load) + getVal(store);
612 vmcnt = getVal(
load);
614 vmcnt = getVal(store);
617 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
619 return op.emitOpError(
"unsupported chipset");
621 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
627 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
628 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
633 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
634 ConversionPatternRewriter &rewriter)
const override {
635 Location loc = op.getLoc();
638 bool requiresInlineAsm = chipset <
kGfx90a;
641 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
650 StringRef scope =
"workgroup";
652 auto relFence = LLVM::FenceOp::create(rewriter, loc,
653 LLVM::AtomicOrdering::release, scope);
654 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
655 if (requiresInlineAsm) {
656 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
657 LLVM::AsmDialect::AD_ATT);
658 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
659 const char *constraints =
"";
660 LLVM::InlineAsmOp::create(
663 asmStr, constraints,
true,
664 false, LLVM::TailCallKind::None,
667 }
else if (chipset.majorVersion < 12) {
668 ROCDL::SBarrierOp::create(rewriter, loc);
670 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
671 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
674 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
675 LLVM::AtomicOrdering::acquire, scope);
676 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
677 rewriter.replaceOp(op, acqFence);
683 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
684 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
689 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
690 ConversionPatternRewriter &rewriter)
const override {
691 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
692 (uint32_t)op.getOpts());
716 bool allowBf16 =
true) {
718 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
719 if (vectorType.getElementType().isBF16() && !allowBf16)
720 return LLVM::BitcastOp::create(
721 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
722 if (vectorType.getElementType().isInteger(8) &&
723 vectorType.getNumElements() <= 8)
724 return LLVM::BitcastOp::create(
726 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
727 if (isa<IntegerType>(vectorType.getElementType()) &&
728 vectorType.getElementTypeBitWidth() <= 8) {
729 int64_t numWords = llvm::divideCeil(
730 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
732 return LLVM::BitcastOp::create(
733 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
743 bool allowBf16 =
true) {
745 auto vectorType = cast<VectorType>(inputType);
747 if (vectorType.getElementType().isBF16() && !allowBf16)
748 return LLVM::BitcastOp::create(
749 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
751 if (isa<IntegerType>(vectorType.getElementType()) &&
752 vectorType.getElementTypeBitWidth() <= 8) {
753 int64_t numWords = llvm::divideCeil(
754 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
755 Type castType = (numWords > 1)
756 ?
Type{VectorType::get(numWords, rewriter.getI32Type())}
757 : rewriter.getI32Type();
758 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
776 .Case([&](IntegerType) {
778 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
781 .Case([&](VectorType vectorType) {
783 int64_t numElements = vectorType.getNumElements();
784 assert((numElements == 4 || numElements == 8) &&
785 "scale operand must be a vector of length 4 or 8");
786 IntegerType outputType =
787 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
788 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
790 .DefaultUnreachable(
"unexpected input type for scale operand");
796 .Case([](Float8E8M0FNUType) {
return 0; })
797 .Case([](Float8E4M3FNType) {
return 2; })
798 .Default(std::nullopt);
803static std::optional<StringRef>
805 if (m == 16 && n == 16 && k == 128)
807 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
808 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
810 if (m == 32 && n == 16 && k == 128)
811 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
812 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
826 ConversionPatternRewriter &rewriter,
Location loc,
831 auto vectorType = dyn_cast<VectorType>(inputType);
833 operands.push_back(llvmInput);
836 Type elemType = vectorType.getElementType();
838 operands.push_back(llvmInput);
845 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
846 bool isInputInteger = mlirInputType.getElementType().isInteger();
847 if (isInputInteger) {
849 bool localIsUnsigned = isUnsigned;
851 localIsUnsigned =
true;
853 localIsUnsigned =
false;
856 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
861 Type i32 = rewriter.getI32Type();
862 Type intrinsicInType = numBits <= 32
863 ? (
Type)rewriter.getIntegerType(numBits)
864 : (
Type)VectorType::get(numBits / 32, i32);
865 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
866 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
867 loc, llvmIntrinsicInType, llvmInput);
872 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
873 operands.push_back(castInput);
886 Value output, int32_t subwordOffset,
890 auto vectorType = dyn_cast<VectorType>(inputType);
891 Type elemType = vectorType.getElementType();
892 operands.push_back(output);
904 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
905 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
911 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
912 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
920 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
921 b = mfma.getBlocks();
926 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
927 if (m == 32 && n == 32 && k == 4 &&
b == 1)
928 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
929 if (m == 16 && n == 16 && k == 8 &&
b == 1)
930 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
932 if (m == 32 && n == 32 && k == 1 &&
b == 2)
933 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
934 if (m == 16 && n == 16 && k == 1 &&
b == 4)
935 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
936 if (m == 4 && n == 4 && k == 1 &&
b == 16)
937 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
938 if (m == 32 && n == 32 && k == 2 &&
b == 1)
939 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
940 if (m == 16 && n == 16 && k == 4 &&
b == 1)
941 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
946 if (m == 32 && n == 32 && k == 16 &&
b == 1)
947 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
948 if (m == 16 && n == 16 && k == 32 &&
b == 1)
949 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
951 if (m == 32 && n == 32 && k == 4 &&
b == 2)
952 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
953 if (m == 16 && n == 16 && k == 4 &&
b == 4)
954 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
955 if (m == 4 && n == 4 && k == 4 &&
b == 16)
956 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
957 if (m == 32 && n == 32 && k == 8 &&
b == 1)
958 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
959 if (m == 16 && n == 16 && k == 16 &&
b == 1)
960 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
965 if (m == 32 && n == 32 && k == 16 &&
b == 1)
966 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
967 if (m == 16 && n == 16 && k == 32 &&
b == 1)
968 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
971 if (m == 32 && n == 32 && k == 4 &&
b == 2)
972 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
973 if (m == 16 && n == 16 && k == 4 &&
b == 4)
974 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
975 if (m == 4 && n == 4 && k == 4 &&
b == 16)
976 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
977 if (m == 32 && n == 32 && k == 8 &&
b == 1)
978 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
979 if (m == 16 && n == 16 && k == 16 &&
b == 1)
980 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
982 if (m == 32 && n == 32 && k == 2 &&
b == 2)
983 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
984 if (m == 16 && n == 16 && k == 2 &&
b == 4)
985 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
986 if (m == 4 && n == 4 && k == 2 &&
b == 16)
987 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
988 if (m == 32 && n == 32 && k == 4 &&
b == 1)
989 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
990 if (m == 16 && n == 16 && k == 8 &&
b == 1)
991 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
996 if (m == 32 && n == 32 && k == 32 &&
b == 1)
997 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
998 if (m == 16 && n == 16 && k == 64 &&
b == 1)
999 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
1001 if (m == 32 && n == 32 && k == 4 &&
b == 2)
1002 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
1003 if (m == 16 && n == 16 && k == 4 &&
b == 4)
1004 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
1005 if (m == 4 && n == 4 && k == 4 &&
b == 16)
1006 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
1007 if (m == 32 && n == 32 && k == 8 &&
b == 1)
1008 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
1009 if (m == 16 && n == 16 && k == 16 &&
b == 1)
1010 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
1011 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
1012 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
1013 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
1014 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
1018 if (m == 16 && n == 16 && k == 4 &&
b == 1)
1019 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
1020 if (m == 4 && n == 4 && k == 4 &&
b == 4)
1021 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
1028 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1029 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
1031 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
1033 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
1035 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
1037 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
1039 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
1045 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1046 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
1048 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
1050 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
1052 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
1054 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
1056 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1060 return std::nullopt;
1065 .Case([](Float8E4M3FNType) {
return 0u; })
1066 .Case([](Float8E5M2Type) {
return 1u; })
1067 .Case([](Float6E2M3FNType) {
return 2u; })
1068 .Case([](Float6E3M2FNType) {
return 3u; })
1069 .Case([](Float4E2M1FNType) {
return 4u; })
1070 .Default(std::nullopt);
1080static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1082 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
1088 return std::nullopt;
1089 if (!isa<Float32Type>(destType))
1090 return std::nullopt;
1094 if (!aTypeCode || !bTypeCode)
1095 return std::nullopt;
1097 if (m == 32 && n == 32 && k == 64 &&
b == 1)
1098 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1099 *aTypeCode, *bTypeCode};
1100 if (m == 16 && n == 16 && k == 128 &&
b == 1)
1102 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1105 return std::nullopt;
1108static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1111 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1112 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1113 mfma.getBlocks(), chipset);
1116static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1119 smfma.getSourceB().getType(),
1120 smfma.getDestC().getType(), smfma.getM(),
1121 smfma.getN(), smfma.getK(), 1u, chipset);
1126static std::optional<StringRef>
1128 Type elemDestType, uint32_t k,
bool isRDNA3) {
1129 using fp8 = Float8E4M3FNType;
1130 using bf8 = Float8E5M2Type;
1135 if (elemSourceType.
isF16() && elemDestType.
isF32())
1136 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1137 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1138 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1139 if (elemSourceType.
isF16() && elemDestType.
isF16())
1140 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1142 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1144 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1149 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1150 return std::nullopt;
1154 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1155 elemDestType.
isF32())
1156 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1157 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1158 elemDestType.
isF32())
1159 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1160 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1161 elemDestType.
isF32())
1162 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1163 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1164 elemDestType.
isF32())
1165 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1167 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1169 return std::nullopt;
1173 if (k == 32 && !isRDNA3) {
1175 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1178 return std::nullopt;
1184 Type elemBSourceType,
1187 using fp8 = Float8E4M3FNType;
1188 using bf8 = Float8E5M2Type;
1191 if (elemSourceType.
isF32() && elemDestType.
isF32())
1192 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1194 return std::nullopt;
1198 if (elemSourceType.
isF16() && elemDestType.
isF32())
1199 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1200 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1201 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1202 if (elemSourceType.
isF16() && elemDestType.
isF16())
1203 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1205 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1207 return std::nullopt;
1211 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1212 if (elemDestType.
isF32())
1213 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1214 if (elemDestType.
isF16())
1215 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1217 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1218 if (elemDestType.
isF32())
1219 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1220 if (elemDestType.
isF16())
1221 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1223 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1224 if (elemDestType.
isF32())
1225 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1226 if (elemDestType.
isF16())
1227 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1229 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1230 if (elemDestType.
isF32())
1231 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1232 if (elemDestType.
isF16())
1233 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1236 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1238 return std::nullopt;
1242 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1243 if (elemDestType.
isF32())
1244 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1245 if (elemDestType.
isF16())
1246 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1248 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1249 if (elemDestType.
isF32())
1250 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1251 if (elemDestType.
isF16())
1252 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1254 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1255 if (elemDestType.
isF32())
1256 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1257 if (elemDestType.
isF16())
1258 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1260 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1261 if (elemDestType.
isF32())
1262 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1263 if (elemDestType.
isF16())
1264 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1267 return std::nullopt;
1270 return std::nullopt;
1278 bool isGfx950 = chipset >=
kGfx950;
1282 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1287 if (m == 16 && n == 16 && k == 32) {
1289 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1291 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1294 if (m == 16 && n == 16 && k == 64) {
1297 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1299 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1303 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1304 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1305 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1306 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1307 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1308 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1309 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1310 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1311 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1314 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1317 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1318 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1319 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1320 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1321 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1322 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1323 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1324 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1325 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1328 if (m == 32 && n == 32 && k == 16) {
1330 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1332 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1335 if (m == 32 && n == 32 && k == 32) {
1338 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1340 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1344 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1345 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1346 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1347 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1348 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1349 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1350 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1351 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1352 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1355 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1358 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1359 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1360 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1361 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1362 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1363 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1364 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1365 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1366 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1369 return std::nullopt;
1377 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1378 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1379 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1380 Type elemSourceType = sourceVectorType.getElementType();
1381 Type elemBSourceType = sourceBVectorType.getElementType();
1382 Type elemDestType = destVectorType.getElementType();
1384 const uint32_t k = wmma.getK();
1389 if (isRDNA3 || isRDNA4)
1398 return std::nullopt;
1411static std::optional<SparseWMMAOpInfo>
1417 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
1419 if ((m != 16) || (n != 16))
1420 return std::nullopt;
1427 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(),
false,
false,
1431 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(),
false,
false,
1435 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(),
false,
false,
1439 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(),
false,
false,
1444 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(),
true,
false,
1449 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(),
true,
false,
1454 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(),
false,
1459 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(),
false,
1464 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(),
false,
1468 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(),
false,
1475 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(),
true,
false,
1480 const bool isGFX1250 = chipset ==
kGfx1250;
1481 const bool isWavesize64 = swmmac.getWave64();
1482 if (isGFX1250 && !isWavesize64) {
1486 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(),
true,
true,
1490 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(),
true,
true,
1494 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(),
true,
true,
1498 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(),
true,
true,
1505 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(),
false,
1510 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(),
false,
1515 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(),
false,
1519 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(),
false,
1524 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(),
false,
1529 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(),
false,
1534 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(),
false,
1538 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
1543 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
1548 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(),
true,
true,
1553 return std::nullopt;
1558 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1559 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1564 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1565 ConversionPatternRewriter &rewriter)
const override {
1566 Location loc = op.getLoc();
1567 Type outType = typeConverter->convertType(op.getDestD().getType());
1568 Type intrinsicOutType = outType;
1569 if (
auto outVecType = dyn_cast<VectorType>(outType))
1570 if (outVecType.getElementType().isBF16())
1571 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1573 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1574 return op->emitOpError(
"MFMA only supported on gfx908+");
1575 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
1576 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1578 return op.emitOpError(
"negation unsupported on older than gfx942");
1580 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1583 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1585 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1586 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
1589 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1591 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1592 return op.emitOpError(
1593 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1594 "be scaled as those fields are used for type information");
1597 StringRef intrinsicName =
1598 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1601 bool allowBf16 = [&]() {
1606 return intrinsicName.contains(
"16x16x32.bf16") ||
1607 intrinsicName.contains(
"32x32x16.bf16");
1609 OperationState loweredOp(loc, intrinsicName);
1610 loweredOp.addTypes(intrinsicOutType);
1612 rewriter, loc, adaptor.getSourceA(), allowBf16),
1614 rewriter, loc, adaptor.getSourceB(), allowBf16),
1615 adaptor.getDestC()});
1618 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1619 loweredOp.addOperands({zero, zero});
1620 loweredOp.addAttributes({{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1621 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1622 {
"opselA", rewriter.getI32IntegerAttr(0)},
1623 {
"opselB", rewriter.getI32IntegerAttr(0)}});
1625 loweredOp.addAttributes(
1626 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1627 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1628 {
"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1630 Value lowered = rewriter.create(loweredOp)->getResult(0);
1631 if (outType != intrinsicOutType)
1632 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1633 rewriter.replaceOp(op, lowered);
1639 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1640 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
// Lowers a ScaledMFMAOp to the ROCDL intrinsic named by the looked-up
// (intrinsic name, A type code, B type code) tuple. The result type is the
// converted type of the op's destD. Emits an op error when the chipset is
// rejected by the guard below or when no scaled intrinsic is found.
1645 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1646 ConversionPatternRewriter &rewriter)
const override {
1647 Location loc = op.getLoc();
1648 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
// The guard requires major version 9 and at least kGfx950, so the diagnostic
// names gfx950 (it previously said gfx908, which did not match the check).
1650 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1651 return op->emitOpError(
"scaled MFMA only supported on gfx950+");
1652 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1654 if (!maybeScaledIntrinsic.has_value())
1655 return op.emitOpError(
1656 "no intrinsic matching scaled MFMA size on given chipset");
// The tuple's type codes are emitted as the cbsz (A) and blgp (B) attributes.
1658 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1659 OperationState loweredOp(loc, intrinsicName);
1660 loweredOp.addTypes(intrinsicOutType);
1661 loweredOp.addOperands(
1664 adaptor.getDestC()});
1665 loweredOp.addOperands(
// opselA / opselB are taken from the op's scalesIdxA / scalesIdxB.
1670 loweredOp.addAttributes(
1671 {{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1672 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1673 {
"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1674 {
"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
// Build the intrinsic from the generic OperationState and replace the op with
// its first result.
1676 Value lowered = rewriter.create(loweredOp)->getResult(0);
1677 rewriter.replaceOp(op, lowered);
1683 SparseMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1684 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1689 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1690 ConversionPatternRewriter &rewriter)
const override {
1691 Location loc = op.getLoc();
1693 typeConverter->convertType<VectorType>(op.getDestC().
getType());
1695 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1698 if (chipset.majorVersion != 9 || chipset <
kGfx942)
1699 return op->emitOpError(
"sparse MFMA (smfmac) only supported on gfx942+");
1702 if (!maybeIntrinsic.has_value())
1703 return op.emitOpError(
1704 "no intrinsic matching sparse MFMA on the given chipset");
1707 ROCDL::smfmac_f32_16x16x32_bf16::getOperationName() ||
1709 ROCDL::smfmac_f32_32x32x16_bf16::getOperationName());
1710 bool isGfx950 = (chipset >=
kGfx950) && !isGfx942BF16;
1716 Value c = adaptor.getDestC();
1720 Value sparseIdx = adaptor.getSparseIdx();
1721 Type i32Type = rewriter.getI32Type();
1722 if (sparseIdx.
getType() != i32Type)
1723 sparseIdx = LLVM::BitcastOp::create(rewriter, loc, i32Type, sparseIdx);
1725 OperationState loweredOp(loc, maybeIntrinsic.value());
1726 loweredOp.addTypes(outType);
1727 loweredOp.addOperands({a,
b, c, sparseIdx});
1728 loweredOp.addAttributes(
1729 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1730 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1731 Value lowered = rewriter.create(loweredOp)->getResult(0);
1732 rewriter.replaceOp(op, lowered);
1738 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1739 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1744 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1745 ConversionPatternRewriter &rewriter)
const override {
1746 Location loc = op.getLoc();
1748 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1750 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1752 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1753 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1755 bool isGFX1250 = chipset >=
kGfx1250;
1760 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1761 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1762 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1763 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1764 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1765 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1766 bool castOutToI16 = outType.getElementType().
isBF16() && !isGFX1250;
1767 VectorType rawOutType = outType;
1769 rawOutType = outType.clone(rewriter.getI16Type());
1770 Value a = adaptor.getSourceA();
1772 a = LLVM::BitcastOp::create(rewriter, loc,
1773 aType.clone(rewriter.getI16Type()), a);
1774 Value
b = adaptor.getSourceB();
1776 b = LLVM::BitcastOp::create(rewriter, loc,
1777 bType.clone(rewriter.getI16Type()),
b);
1778 Value destC = adaptor.getDestC();
1780 destC = LLVM::BitcastOp::create(
1781 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1785 if (!maybeIntrinsic.has_value())
1786 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1788 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1789 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1791 SmallVector<Value, 4> operands;
1792 SmallVector<NamedAttribute, 4> attrs;
1794 op.getSourceA(), operands, attrs,
"signA");
1796 op.getSourceB(), operands, attrs,
"signB");
1798 op.getSubwordOffset(), op.getClamp(), operands,
1801 OperationState loweredOp(loc, *maybeIntrinsic);
1802 loweredOp.addTypes(rawOutType);
1803 loweredOp.addOperands(operands);
1804 loweredOp.addAttributes(attrs);
1805 Operation *lowered = rewriter.create(loweredOp);
1807 Operation *maybeCastBack = lowered;
1808 if (rawOutType != outType)
1809 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1811 rewriter.replaceOp(op, maybeCastBack->
getResults());
1817enum class DotFamily {
1826static std::optional<std::pair<StringRef, DotFamily>>
1827dotOpToIntrinsic(DotOp op,
Chipset chipset) {
1828 Type aElem = cast<VectorType>(op.getSourceA().getType()).getElementType();
1829 Type bElem = cast<VectorType>(op.getSourceB().getType()).getElementType();
1830 Type dest = op.getDestC().getType();
1831 bool uA = op.getUnsignedA();
1832 bool uB = op.getUnsignedB();
1837 return {{ROCDL::fdot2::getOperationName(), DotFamily::Clamp}};
1839 return {{ROCDL::fdot2_f16_f16::getOperationName(), DotFamily::NoClamp}};
1840 return std::nullopt;
1846 return {{ROCDL::fdot2_f32_bf16::getOperationName(), DotFamily::Clamp}};
1848 return {{ROCDL::fdot2_bf16_bf16::getOperationName(), DotFamily::NoClamp}};
1849 return std::nullopt;
1853 if (isa<IntegerType>(aElem) && isa<IntegerType>(bElem) &&
1855 bool mixedSign = (uA != uB);
1860 return std::nullopt;
1862 switch (elemWidth) {
1864 name = ROCDL::sudot4::getOperationName();
1867 name = ROCDL::sudot8::getOperationName();
1870 return std::nullopt;
1872 return {{name, DotFamily::Sudot}};
1876 bool supported =
false;
1877 switch (elemWidth) {
1880 name = uA ? ROCDL::udot2::getOperationName()
1881 :
ROCDL::sdot2::getOperationName();
1886 name = uA ? ROCDL::udot4::getOperationName()
1887 :
ROCDL::sdot4::getOperationName();
1892 name = uA ? ROCDL::udot8::getOperationName()
1893 :
ROCDL::sdot8::getOperationName();
1896 return std::nullopt;
1899 return std::nullopt;
1900 return {{name, DotFamily::Clamp}};
1904 bool aIsFp8 = isa<Float8E4M3FNType>(aElem);
1905 bool aIsBf8 = isa<Float8E5M2Type>(aElem);
1906 bool bIsFp8 = isa<Float8E4M3FNType>(bElem);
1907 bool bIsBf8 = isa<Float8E5M2Type>(bElem);
1908 if ((aIsFp8 || aIsBf8) && (bIsFp8 || bIsBf8) && dest.
isF32()) {
1910 return std::nullopt;
1912 if (aIsFp8 && bIsFp8)
1913 name = ROCDL::dot4_f32_fp8_fp8::getOperationName();
1914 else if (aIsFp8 && bIsBf8)
1915 name = ROCDL::dot4_f32_fp8_bf8::getOperationName();
1916 else if (aIsBf8 && bIsFp8)
1917 name = ROCDL::dot4_f32_bf8_fp8::getOperationName();
1919 name = ROCDL::dot4_f32_bf8_bf8::getOperationName();
1920 return {{name, DotFamily::NoClamp}};
1923 return std::nullopt;
1927 DotOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1928 : ConvertOpToLLVMPattern<DotOp>(converter), chipset(chipset) {}
// Lowers a DotOp to the intrinsic chosen by dotOpToIntrinsic(op, chipset).
// When no intrinsic is found, emits an op error that lists the sourceA,
// sourceB and destC types. The DotFamily returned with the intrinsic name
// decides which optional attributes are attached.
1933 matchAndRewrite(DotOp op, DotOpAdaptor adaptor,
1934 ConversionPatternRewriter &rewriter)
const override {
1935 Location loc = op.getLoc();
1937 std::optional<std::pair<StringRef, DotFamily>> maybeIntrinsic =
1938 dotOpToIntrinsic(op, chipset);
1939 if (!maybeIntrinsic)
1940 return op.emitOpError(
"no intrinsic matching dot on the given chipset: ")
1941 << op.getSourceA().getType() <<
" * " << op.getSourceB().getType()
1942 <<
" + " << op.getDestC().getType();
1944 auto [intrinsicName, family] = maybeIntrinsic.value();
1948 Value c = adaptor.getDestC();
1950 SmallVector<NamedAttribute, 3> attrs;
// Sudot family: signA / signB are the negation of the op's unsigned flags.
1951 if (family == DotFamily::Sudot) {
1952 attrs.push_back(rewriter.getNamedAttr(
1953 "signA", rewriter.getBoolAttr(!op.getUnsignedA())));
1954 attrs.push_back(rewriter.getNamedAttr(
1955 "signB", rewriter.getBoolAttr(!op.getUnsignedB())));
// A clamp attribute is built only for families other than NoClamp, and only
// when the op requests clamping.
1958 if (family != DotFamily::NoClamp && op.getClamp())
1960 rewriter.getNamedAttr(
"clamp", rewriter.getBoolAttr(
true)));
1962 Type resultType = typeConverter->convertType(op.getDestD().getType());
// Build the intrinsic generically by name and replace the op with its results.
1964 OperationState loweredOp(loc, intrinsicName);
1965 loweredOp.addTypes(resultType);
1966 loweredOp.addOperands({a,
b, c});
1967 loweredOp.addAttributes(attrs);
1968 Operation *lowered = rewriter.create(loweredOp);
1969 rewriter.replaceOp(op, lowered->
getResults());
1975 SparseWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1976 : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
// Lowers a SparseWMMAOp to the intrinsic described by a SparseWMMAOpInfo
// record (fields read here: name, useSign, useReuse, useClamp). Returns a
// match failure when type conversion fails, and emits an op error when no
// intrinsic is found or when the op requests a feature the chosen intrinsic
// does not support.
1981 matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
1982 ConversionPatternRewriter &rewriter)
const override {
1983 Location loc = op.getLoc();
1985 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1987 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1989 std::optional<SparseWMMAOpInfo> maybeIntrinsic =
1992 if (!maybeIntrinsic.has_value())
1993 return op.emitOpError(
1994 "no intrinsic matching Sparse WMMA on the given chipset");
1995 SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
1997 SmallVector<NamedAttribute> attrs;
// Unsigned operands are rejected unless the intrinsic has useSign set.
1999 if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.
useSign)
2000 return op->emitOpError(
"intrinsic doesn't support unsign");
// NOTE(review): signA / signB receive the op's unsignedA / unsignedB
// attributes unchanged, whereas DotOpLowering in this file negates the
// unsigned flag for its signA / signB -- confirm the polarity is intended.
2002 if (
auto attr = op.getUnsignedAAttr())
2003 attrs.push_back({
"signA", attr});
2004 if (
auto attr = op.getUnsignedBAttr())
2005 attrs.push_back({
"signB", attr});
// Operand reuse is rejected unless the intrinsic has useReuse set.
2008 if ((op.getReuseA() || op.getReuseB()) && !intrinsic.
useReuse)
2009 return op->emitOpError(
"intrinsic doesn't support reuse");
2011 if (
auto attr = op.getReuseAAttr())
2012 attrs.push_back({
"reuseA", attr});
2013 if (
auto attr = op.getReuseBAttr())
2014 attrs.push_back({
"reuseB", attr});
// Clamping is rejected unless the intrinsic has useClamp set.
2017 if (op.getClamp() && !intrinsic.
useClamp)
2018 return op->emitOpError(
"intrinsic doesn't support clamp");
2019 if (intrinsic.
useClamp && op.getClampAttr())
2020 attrs.push_back({
"clamp", op.getClampAttr()});
// NOTE(review): this test is true only for major version 12 with minor >= 5;
// other code in this file compares against kGfx1250 instead -- confirm that
// excluding major versions above 12 is intended.
2022 const bool isGFX1250orHigher =
2023 chipset.majorVersion == 12 && chipset.minorVersion >= 5;
2028 Value c = adaptor.getDestC();
2029 VectorType rawOutType = outType;
// Below gfx12.5 the intrinsic result type is taken from the type of c.
2030 if (!isGFX1250orHigher) {
2032 rawOutType = cast<VectorType>(c.
getType());
// The sparse index operand is bitcast to i32 before being passed on.
2036 Value sparseIdx = LLVM::BitcastOp::create(
2037 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
2039 OperationState loweredOp(loc, intrinsic.
name);
2040 loweredOp.addTypes(rawOutType);
2041 loweredOp.addOperands({a,
b, c, sparseIdx});
2042 loweredOp.addAttributes(attrs);
2043 Operation *lowered = rewriter.create(loweredOp);
// If the raw result type differs from the converted output type, the result
// is bitcast back before the op is replaced.
2045 Operation *maybeCastBack = lowered;
2046 if (rawOutType != outType)
2047 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
2049 rewriter.replaceOp(op, maybeCastBack->
getResults());
2056 ScaledWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2057 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
2062 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
2063 ConversionPatternRewriter &rewriter)
const override {
2064 Location loc = op.getLoc();
2066 typeConverter->convertType<VectorType>(op.getDestD().
getType());
2068 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2071 return op->emitOpError(
"WMMA scale only supported on gfx1250+");
2073 int64_t m = op.getM();
2074 int64_t n = op.getN();
2075 int64_t k = op.getK();
2083 if (!aFmtCode || !bFmtCode)
2084 return op.emitOpError(
"unsupported element types for scaled_wmma");
2087 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
2088 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
2090 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
2091 return op.emitOpError(
"scaleA and scaleB must have equal vector length");
2094 Type scaleAElemType = scaleAVecType.getElementType();
2095 Type scaleBElemType = scaleBVecType.getElementType();
2100 if (!scaleAFmt || !scaleBFmt)
2101 return op.emitOpError(
"unsupported scale element types");
2104 bool isScale16 = (scaleAVecType.getNumElements() == 8);
2105 std::optional<StringRef> intrinsicName =
2108 return op.emitOpError(
"unsupported scaled_wmma dimensions: ")
2109 << m <<
"x" << n <<
"x" << k;
2111 SmallVector<NamedAttribute, 8> attrs;
2114 bool is32x16 = (m == 32 && n == 16 && k == 128);
2116 attrs.emplace_back(
"fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
2117 attrs.emplace_back(
"fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
2121 attrs.emplace_back(
"modC", rewriter.getI16IntegerAttr(0));
2126 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
2127 attrs.emplace_back(
"fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
2129 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
2130 attrs.emplace_back(
"fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
2133 attrs.emplace_back(
"reuseA", rewriter.getBoolAttr(
false));
2134 attrs.emplace_back(
"reuseB", rewriter.getBoolAttr(
false));
2147 OperationState loweredOp(loc, *intrinsicName);
2148 loweredOp.addTypes(outType);
2149 loweredOp.addOperands(
2150 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
2151 loweredOp.addAttributes(attrs);
2153 Operation *lowered = rewriter.create(loweredOp);
2154 rewriter.replaceOp(op, lowered->
getResults());
2160struct TransposeLoadOpLowering
2162 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2163 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
2168 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
2169 ConversionPatternRewriter &rewriter)
const override {
2171 return op.emitOpError(
"Non-gfx950 chipset not supported");
2173 Location loc = op.getLoc();
2174 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2178 size_t srcElementSize =
2179 srcMemRefType.getElementType().getIntOrFloatBitWidth();
2180 if (srcElementSize < 8)
2181 return op.emitOpError(
"Expect source memref to have at least 8 bits "
2182 "element size, got ")
2185 auto resultType = cast<VectorType>(op.getResult().getType());
2188 (adaptor.getSrcIndices()));
2190 size_t numElements = resultType.getNumElements();
2191 size_t elementTypeSize =
2196 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
2197 rewriter.getIntegerType(32));
2198 Type llvmResultType = typeConverter->convertType(resultType);
2200 switch (elementTypeSize) {
2202 assert(numElements == 16);
2203 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
2204 rocdlResultType, srcPtr);
2205 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2209 assert(numElements == 16);
2210 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
2211 rocdlResultType, srcPtr);
2212 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2216 assert(numElements == 8);
2217 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
2218 rocdlResultType, srcPtr);
2219 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2223 assert(numElements == 4);
2224 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
2229 return op.emitOpError(
"Unsupported element size for transpose load");
2235struct GlobalTransposeLoadOpLowering
2237 GlobalTransposeLoadOpLowering(
const LLVMTypeConverter &converter,
2239 : ConvertOpToLLVMPattern<GlobalTransposeLoadOp>(converter),
2245 matchAndRewrite(GlobalTransposeLoadOp op,
2246 GlobalTransposeLoadOpAdaptor adaptor,
2247 ConversionPatternRewriter &rewriter)
const override {
2249 return op.emitOpError(
2250 "global_transpose_load is only supported on gfx1200+");
2252 Location loc = op.getLoc();
2253 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2254 auto resultType = cast<VectorType>(op.getResult().getType());
2257 rewriter, loc, srcMemRefType, adaptor.getSrc(), adaptor.getSrcIndices(),
2258 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
2260 size_t numElements = resultType.getNumElements();
2261 size_t elementTypeSize =
2266 Type rocdlResultType =
2267 elementTypeSize < 16
2268 ? VectorType::get((numElements * elementTypeSize) / 32,
2269 rewriter.getIntegerType(32))
2270 : typeConverter->convertType(resultType);
2271 Type llvmResultType = typeConverter->convertType(resultType);
2273 switch (elementTypeSize) {
2275 assert(numElements == 16);
2277 return op.emitOpError(
"4-bit global_transpose_load requires gfx1250+");
2278 auto rocdlOp = ROCDL::GlobalLoadTr4_B64::create(rewriter, loc,
2279 rocdlResultType, srcPtr);
2280 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2284 assert(numElements == 16);
2286 return op.emitOpError(
"6-bit global_transpose_load requires gfx1250+");
2287 auto rocdlOp = ROCDL::GlobalLoadTr6_B96::create(rewriter, loc,
2288 rocdlResultType, srcPtr);
2289 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2293 assert(numElements == 8);
2294 auto rocdlOp = ROCDL::GlobalLoadTr8_B64::create(rewriter, loc,
2295 rocdlResultType, srcPtr);
2296 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2300 assert(numElements == 8);
2301 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadTr8_B128>(op, llvmResultType,
2306 return op.emitOpError(
2307 "unsupported element size for global transpose load");
2314 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2315 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2320 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2321 ConversionPatternRewriter &rewriter)
const override {
2322 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2323 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
2325 Location loc = op.getLoc();
2327 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2328 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2333 Type transferType = op.getTransferType();
2334 int loadWidth = [&]() ->
int {
2335 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2336 return (transferVectorType.getNumElements() *
2337 transferVectorType.getElementTypeBitWidth()) /
2344 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2345 return op.emitOpError(
"chipset unsupported element size");
2347 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2348 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
2349 "16-byte load widths are only supported on gfx950");
2353 (adaptor.getSrcIndices()));
2356 (adaptor.getDstIndices()));
2358 if (op.getAsync()) {
2359 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2360 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2361 rewriter.getI32IntegerAttr(0),
2365 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2366 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2367 rewriter.getI32IntegerAttr(0),
2376struct GlobalLoadAsyncToLDSOpLowering
2378 GlobalLoadAsyncToLDSOpLowering(
const LLVMTypeConverter &converter,
2380 : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),
2386 matchAndRewrite(GlobalLoadAsyncToLDSOp op,
2387 GlobalLoadAsyncToLDSOpAdaptor adaptor,
2388 ConversionPatternRewriter &rewriter)
const override {
2390 return op.emitOpError(
2391 "global_load_async_to_lds is only supported on gfx1250+");
2393 Location loc = op.getLoc();
2394 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2395 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2397 Type transferType = op.getTransferType();
2399 isa<VectorType>(transferType)
2400 ? cast<VectorType>(transferType).getNumElements() *
2401 cast<VectorType>(transferType).getElementTypeBitWidth()
2406 adaptor.getSrcIndices());
2409 adaptor.getDstIndices());
2412 Value mask = adaptor.getMask();
2413 int64_t nullptrVal =
2414 llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
2418 LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.
getType(), nullInt);
2419 dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
2422 auto offset = rewriter.getI32IntegerAttr(0);
2423 auto aux = rewriter.getI32IntegerAttr(0);
2425 switch (transferBits) {
2427 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
2432 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
2437 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
2442 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
2447 return op.emitOpError(
"unsupported transfer width");
2454struct ExtPackedFp8OpLowering final
2456 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2457 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
2462 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2463 ConversionPatternRewriter &rewriter)
const override;
2466struct ScaledExtPackedMatrixOpLowering final
2468 ScaledExtPackedMatrixOpLowering(
const LLVMTypeConverter &converter,
2470 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
2475 matchAndRewrite(ScaledExtPackedMatrixOp op,
2476 ScaledExtPackedMatrixOpAdaptor adaptor,
2477 ConversionPatternRewriter &rewriter)
const override;
2480struct PackedTrunc2xFp8OpLowering final
2482 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
2484 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
2489 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2490 ConversionPatternRewriter &rewriter)
const override;
2493struct PackedStochRoundFp8OpLowering final
2495 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
2497 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
2502 matchAndRewrite(PackedStochRoundFp8Op op,
2503 PackedStochRoundFp8OpAdaptor adaptor,
2504 ConversionPatternRewriter &rewriter)
const override;
2507struct ScaledExtPackedOpLowering final
2509 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2510 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
2515 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2516 ConversionPatternRewriter &rewriter)
const override;
2519struct PackedScaledTruncOpLowering final
2521 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
2523 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
2528 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2529 ConversionPatternRewriter &rewriter)
const override;
2534LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
2535 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2536 ConversionPatternRewriter &rewriter)
const {
2537 Location loc = op.getLoc();
2539 return rewriter.notifyMatchFailure(
2540 loc,
"Fp8 conversion instructions are not available on target "
2541 "architecture and their emulation is not implemented");
2543 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
2544 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2545 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
2547 Value source = adaptor.getSource();
2548 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
2549 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
2552 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
2553 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
2554 if (!sourceVecType) {
2555 longVec = LLVM::InsertElementOp::create(
2558 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2560 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2562 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2567 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2568 if (resultVecType) {
2570 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
2573 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
2578 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
2581 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
// Packs blockSize, scaleWaveHalf and firstScaleByte into a small bitfield.
// Asserted preconditions: blockSize is 16 or 32, bitWidth is 4, 6 or 8, and
// scaleWaveHalf is 0 or 1. Two encodings are visible below; the condition
// that chooses between them is not visible here (presumably keyed on isFp8 --
// TODO confirm).
2588int32_t getScaleSel(int32_t blockSize,
unsigned bitWidth, int32_t scaleWaveHalf,
2589 int32_t firstScaleByte) {
2595 assert(llvm::is_contained({16, 32}, blockSize));
2596 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
2598 const bool isFp8 = bitWidth == 8;
2599 const bool isBlock16 = blockSize == 16;
// Three-bit encoding: bit 0 is set for block size 16, bit 1 is set when
// firstScaleByte is 2, bit 2 is scaleWaveHalf. firstScaleByte must be 0, 1
// or 2 here.
2602 int32_t bit0 = isBlock16;
2603 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
2604 int32_t bit1 = (firstScaleByte == 2) << 1;
2605 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2606 int32_t bit2 = scaleWaveHalf << 2;
2607 return bit2 | bit1 | bit0;
// Four-bit encoding: bit 0 is set for block size 16, bits 2:1 hold
// firstScaleByte (0 to 3), bit 3 is scaleWaveHalf.
2610 int32_t bit0 = isBlock16;
2612 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
2613 int32_t bits2and1 = firstScaleByte << 1;
2614 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2615 int32_t bit3 = scaleWaveHalf << 3;
2616 int32_t bits = bit3 | bits2and1 | bit0;
// The listed bit patterns are asserted never to occur.
2618 assert(!llvm::is_contained(
2619 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
// Maps a (source element type, destination element type) pair to the
// operation name of the matching ROCDL CvtPkScale* op. Sources fp4, fp8 and
// bf8 select the Pk8 ops; sources fp6 and bf6 select the Pk16 ops. For a
// recognised source type, returns std::nullopt when the destination is not
// f16, bf16 or f32. An unrecognised source type reaches llvm_unreachable.
2623static std::optional<StringRef>
2624scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2625 using fp4 = Float4E2M1FNType;
2626 using fp8 = Float8E4M3FNType;
2627 using bf8 = Float8E5M2Type;
2628 using fp6 = Float6E2M3FNType;
2629 using bf6 = Float6E3M2FNType;
// fp4 source: Pk8 ops.
2630 if (isa<fp4>(srcElemType)) {
2631 if (destElemType.
isF16())
2632 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2633 if (destElemType.
isBF16())
2634 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2635 if (destElemType.
isF32())
2636 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2637 return std::nullopt;
// fp8 source: Pk8 ops.
2639 if (isa<fp8>(srcElemType)) {
2640 if (destElemType.
isF16())
2641 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2642 if (destElemType.
isBF16())
2643 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2644 if (destElemType.
isF32())
2645 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2646 return std::nullopt;
// bf8 source: Pk8 ops.
2648 if (isa<bf8>(srcElemType)) {
2649 if (destElemType.
isF16())
2650 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2651 if (destElemType.
isBF16())
2652 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2653 if (destElemType.
isF32())
2654 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2655 return std::nullopt;
// fp6 source: Pk16 ops.
2657 if (isa<fp6>(srcElemType)) {
2658 if (destElemType.
isF16())
2659 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2660 if (destElemType.
isBF16())
2661 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2662 if (destElemType.
isF32())
2663 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2664 return std::nullopt;
// bf6 source: Pk16 ops.
2666 if (isa<bf6>(srcElemType)) {
2667 if (destElemType.
isF16())
2668 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2669 if (destElemType.
isBF16())
2670 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2671 if (destElemType.
isF32())
2672 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2673 return std::nullopt;
// Reached only when srcElemType matched none of the five aliases above.
2675 llvm_unreachable(
"invalid combination of element types for packed conversion "
2679LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2680 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2681 ConversionPatternRewriter &rewriter)
const {
2682 using fp4 = Float4E2M1FNType;
2683 using fp8 = Float8E4M3FNType;
2684 using bf8 = Float8E5M2Type;
2685 using fp6 = Float6E2M3FNType;
2686 using bf6 = Float6E3M2FNType;
2687 Location loc = op.getLoc();
2689 return rewriter.notifyMatchFailure(
2691 "Scaled fp packed conversion instructions are not available on target "
2692 "architecture and their emulation is not implemented");
2696 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2697 int32_t firstScaleByte = op.getFirstScaleByte();
2698 int32_t blockSize = op.getBlockSize();
2699 auto sourceType = cast<VectorType>(op.getSource().getType());
2700 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2701 unsigned bitWidth = srcElemType.getWidth();
2703 auto targetType = cast<VectorType>(op.getResult().getType());
2704 auto destElemType = cast<FloatType>(targetType.getElementType());
2706 IntegerType i32 = rewriter.getI32Type();
2707 Value source = adaptor.getSource();
2708 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2709 Type packedType =
nullptr;
2710 if (isa<fp4>(srcElemType)) {
2712 packedType = getTypeConverter()->convertType(packedType);
2713 }
else if (isa<fp8, bf8>(srcElemType)) {
2714 packedType = VectorType::get(2, i32);
2715 packedType = getTypeConverter()->convertType(packedType);
2716 }
else if (isa<fp6, bf6>(srcElemType)) {
2717 packedType = VectorType::get(3, i32);
2718 packedType = getTypeConverter()->convertType(packedType);
2720 llvm_unreachable(
"invalid element type for packed scaled ext");
2723 if (!packedType || !llvmResultType) {
2724 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2727 std::optional<StringRef> maybeIntrinsic =
2728 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2729 if (!maybeIntrinsic.has_value())
2730 return op.emitOpError(
2731 "no intrinsic matching packed scaled conversion on the given chipset");
2734 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2736 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2737 Value castedSource =
2738 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2740 OperationState loweredOp(loc, *maybeIntrinsic);
2741 loweredOp.addTypes({llvmResultType});
2742 loweredOp.addOperands({castedSource, castedScale});
2744 SmallVector<NamedAttribute, 1> attrs;
2746 NamedAttribute(
"scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2748 loweredOp.addAttributes(attrs);
2749 Operation *lowered = rewriter.create(loweredOp);
2750 rewriter.replaceOp(op, lowered);
/// Lowers ScaledExtPackedOp to the ROCDL::CvtScaleF32Pk* op selected by the
/// (source element type, result element type) pair.
///
/// A packed integer vector type is chosen from the source element type
/// (4 x i8 for Float8E5M2 / Float8E4M3FN, 8 x i4 for Float4E2M1FN), the source
/// is bitcast to i32, and the selected op receives that i32, the scale and the
/// op's index.
2755LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2756 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2757 ConversionPatternRewriter &rewriter)
const {
2758 Location loc = op.getLoc();
// Reported as a match failure: per the message, there is no emulation path.
2760 return rewriter.notifyMatchFailure(
2761 loc,
"Scaled fp conversion instructions are not available on target "
2762 "architecture and their emulation is not implemented");
2763 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2765 Value source = adaptor.getSource();
2766 Value scale = adaptor.getScale();
2768 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2769 Type sourceElemType = sourceVecType.getElementType();
2770 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2771 Type destElemType = destVecType.getElementType();
// Pick the packed integer vector type matching the source element type.
2773 VectorType packedVecType;
2774 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2775 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2776 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2777 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
2778 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2779 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2781 llvm_unreachable(
"invalid element type for scaled ext");
// A source with fewer elements than the packed vector is copied element by
// element into a zero-initialized packed vector.
2785 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2786 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
// NOTE(review): sourceVecType is produced by cast<VectorType> above, which
// does not return null, so this branch looks unreachable — confirm.
2787 if (!sourceVecType) {
2788 longVec = LLVM::InsertElementOp::create(
2791 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2793 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2795 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2800 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
// Dispatch on (source element type, destination element type).
2802 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
2803 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2804 op, destVecType, i32Source, scale, op.getIndex());
2805 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
2806 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2807 op, destVecType, i32Source, scale, op.getIndex());
2808 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
2809 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2810 op, destVecType, i32Source, scale, op.getIndex());
2811 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
2812 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2813 op, destVecType, i32Source, scale, op.getIndex());
2814 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
2815 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2816 op, destVecType, i32Source, scale, op.getIndex());
2817 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
2818 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2819 op, destVecType, i32Source, scale, op.getIndex());
2820 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
2821 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2822 op, destVecType, i32Source, scale, op.getIndex());
2823 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
2824 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2825 op, destVecType, i32Source, scale, op.getIndex());
2826 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
2827 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2828 op, destVecType, i32Source, scale, op.getIndex());
/// Lowers PackedScaledTruncOp to the ROCDL::CvtScaleF32Pk* op selected by the
/// (source element type, result element type) pair.
///
/// The ROCDL op produces an integer-typed value (i32 when the result element
/// type is Float4E2M1FN, otherwise the converted 2 x i16 vector); that value
/// is bitcast to the converted result type, and the bitcast replaces the op.
2835LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2836 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2837 ConversionPatternRewriter &rewriter)
const {
2838 Location loc = op.getLoc();
// Reported as a match failure: per the message, there is no emulation path.
2840 return rewriter.notifyMatchFailure(
2841 loc,
"Scaled fp conversion instructions are not available on target "
2842 "architecture and their emulation is not implemented");
2843 Type v2i16 = getTypeConverter()->convertType(
2844 VectorType::get(2, rewriter.getI16Type()));
2845 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2847 Type resultType = op.getResult().getType();
2849 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2850 Type sourceElemType = sourceVecType.getElementType();
// Integer type in which the ROCDL op returns its packed result.
2852 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2854 Value source = adaptor.getSource();
2855 Value scale = adaptor.getScale();
2856 Value existing = adaptor.getExisting();
// `existing` is either the operand bitcast to intResultType or a zero value.
2858 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2860 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
// A source with fewer than two elements is widened to a 2-element vector:
// its element at c0 is placed in a zero vector.
2862 if (sourceVecType.getNumElements() < 2) {
2864 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2865 VectorType v2 = VectorType::get(2, sourceElemType);
2866 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2867 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
// The f32 variants below take two scalar operands instead of a vector.
2870 Value sourceA, sourceB;
2871 if (sourceElemType.
isF32()) {
2874 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2875 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
// Dispatch on (source element type, result element type).
2879 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
2880 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2881 existing, sourceA, sourceB,
2882 scale, op.getIndex());
2883 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
2884 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2885 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2886 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
2887 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2888 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2889 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
2890 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2891 existing, sourceA, sourceB,
2892 scale, op.getIndex());
2893 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
2894 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2895 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2896 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
2897 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2898 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2899 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
2900 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2901 existing, sourceA, sourceB,
2902 scale, op.getIndex());
2903 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
2904 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2905 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2906 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
2907 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2908 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
// Reinterpret the integer result as the converted result type.
2912 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2913 op, getTypeConverter()->convertType(resultType),
result);
/// Lowers PackedTrunc2xFp8Op to ROCDL::CvtPkBf8F32Op or ROCDL::CvtPkFp8F32Op.
///
/// Both ROCDL ops take the two sources, an i32 `existing` value and the op's
/// word index, and return an i32 that is bitcast to the converted result type.
2917LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2918 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2919 ConversionPatternRewriter &rewriter)
const {
2920 Location loc = op.getLoc();
// Reported as a match failure: per the message, there is no emulation path.
2922 return rewriter.notifyMatchFailure(
2923 loc,
"Fp8 conversion instructions are not available on target "
2924 "architecture and their emulation is not implemented");
2925 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2927 Type resultType = op.getResult().getType();
2930 Value sourceA = adaptor.getSourceA();
2931 Value sourceB = adaptor.getSourceB();
// NOTE(review): presumably applies only when sourceB is absent — confirm.
2933 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
2934 Value existing = adaptor.getExisting();
// `existing` is either the operand bitcast to i32 or an undef i32.
2936 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2938 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2942 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2943 existing, op.getWordIndex());
2945 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2946 existing, op.getWordIndex());
// Reinterpret the i32 result as the converted result type.
2948 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2949 op, getTypeConverter()->convertType(resultType),
result);
/// Lowers PackedStochRoundFp8Op to ROCDL::CvtSrBf8F32Op or
/// ROCDL::CvtSrFp8F32Op.
///
/// Both ROCDL ops take the source, the stochastic parameter, an i32 `existing`
/// value and the op's store index, and return an i32 that is bitcast to the
/// converted result type.
2953LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2954 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2955 ConversionPatternRewriter &rewriter)
const {
2956 Location loc = op.getLoc();
// Reported as a match failure: per the message, there is no emulation path.
2958 return rewriter.notifyMatchFailure(
2959 loc,
"Fp8 conversion instructions are not available on target "
2960 "architecture and their emulation is not implemented");
2961 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2963 Type resultType = op.getResult().getType();
2966 Value source = adaptor.getSource();
2967 Value stoch = adaptor.getStochiasticParam();
2968 Value existing = adaptor.getExisting();
// `existing` is either the operand bitcast to i32 or an undef i32.
2970 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2972 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2976 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2977 existing, op.getStoreIndex());
2979 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2980 existing, op.getStoreIndex());
// Reinterpret the i32 result as the converted result type.
2982 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2983 op, getTypeConverter()->convertType(resultType),
result);
/// Conversion pattern that lowers DPPOp to ROCDL::DPPUpdateOp.
///
/// The `src` and `old` operands are brought to a common LLVM type, the
/// permutation kind (plus its optional argument) is encoded into a single
/// control word, and the row mask, bank mask and bound-control attributes are
/// forwarded to the ROCDL op. The result is converted back towards the source
/// type with a truncate and, for non-integer sources, a bitcast.
2989struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
2990 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2991 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2995 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2996 ConversionPatternRewriter &rewriter)
const override {
2999 Location loc = DppOp.getLoc();
3000 Value src = adaptor.getSrc();
3001 Value old = adaptor.getOld();
// Choose the LLVM type the DPP op operates on: i32, f32/f64 or i32/i64.
3004 Type llvmType =
nullptr;
3006 llvmType = rewriter.getI32Type();
3007 }
else if (isa<FloatType>(srcType)) {
3009 ? rewriter.getF32Type()
3010 : rewriter.getF64Type();
3011 }
else if (isa<IntegerType>(srcType)) {
3013 ? rewriter.getI32Type()
3014 : rewriter.getI64Type();
3016 auto llvmSrcIntType = typeConverter->convertType(
// Widens operands of at most 16 bits: floats are first bitcast to the
// same-width integer type, the value is inserted into an undef vector with
// 32 / bitwidth elements, and that vector is bitcast to llvmType.
3020 auto convertOperand = [&](Value operand, Type operandType) {
3021 if (operandType.getIntOrFloatBitWidth() <= 16) {
3022 if (llvm::isa<FloatType>(operandType)) {
3024 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
3026 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
3027 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
3028 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
3030 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
3032 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
3037 src = convertOperand(src, srcType);
3038 old = convertOperand(old, oldType);
// Control-word encodings for the supported permutations.
3041 enum DppCtrl :
unsigned {
3050 ROW_HALF_MIRROR = 0x141,
3055 auto kind = DppOp.getKind();
3056 auto permArgument = DppOp.getPermArgument();
// Note: this local shares its name with the enum declared above.
3057 uint32_t DppCtrl = 0;
3061 case DPPPerm::quad_perm: {
3062 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
// Each entry of the array is OR-ed in at bit position 2 * i.
3064 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
3065 uint32_t num = elem.getInt();
3066 DppCtrl |= num << (i * 2);
// Row shifts/rotations: the integer argument is added to the base encoding.
3071 case DPPPerm::row_shl: {
3072 auto intAttr = cast<IntegerAttr>(*permArgument);
3073 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
3076 case DPPPerm::row_shr: {
3077 auto intAttr = cast<IntegerAttr>(*permArgument);
3078 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
3081 case DPPPerm::row_ror: {
3082 auto intAttr = cast<IntegerAttr>(*permArgument);
3083 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
// The remaining kinds map directly to a fixed encoding.
3086 case DPPPerm::wave_shl:
3087 DppCtrl = DppCtrl::WAVE_SHL1;
3089 case DPPPerm::wave_shr:
3090 DppCtrl = DppCtrl::WAVE_SHR1;
3092 case DPPPerm::wave_rol:
3093 DppCtrl = DppCtrl::WAVE_ROL1;
3095 case DPPPerm::wave_ror:
3096 DppCtrl = DppCtrl::WAVE_ROR1;
3098 case DPPPerm::row_mirror:
3099 DppCtrl = DppCtrl::ROW_MIRROR;
3101 case DPPPerm::row_half_mirror:
3102 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
3104 case DPPPerm::row_bcast_15:
3105 DppCtrl = DppCtrl::BCAST15;
3107 case DPPPerm::row_bcast_31:
3108 DppCtrl = DppCtrl::BCAST31;
// Read the masks and bound-control flag from the op's attributes by name.
3114 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
3115 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
3116 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
3120 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
3121 rowMask, bankMask, boundCtrl);
3123 Value
result = dppMovOp.getRes();
// Undo the operand widening: truncate, then bitcast back for non-integers.
3125 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
3126 if (!llvm::isa<IntegerType>(srcType)) {
3127 result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
/// Conversion pattern that lowers SwizzleBitModeOp to ROCDL::DsSwizzleOp.
///
/// The source value is decomposed into i32 pieces, each piece is swizzled
/// with the same mask, and the pieces are recomposed into the source type.
3138struct AMDGPUSwizzleBitModeLowering
3139 :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
3143 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
3144 ConversionPatternRewriter &rewriter)
const override {
3145 Location loc = op.getLoc();
3146 Type i32 = rewriter.getI32Type();
3147 Value src = adaptor.getSrc();
3148 SmallVector<Value> decomposed;
3149 if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3150 return rewriter.notifyMatchFailure(op,
3151 "failed to decompose value to i32");
3152 unsigned andMask = op.getAndMask();
3153 unsigned orMask = op.getOrMask();
3154 unsigned xorMask = op.getXorMask();
// Pack the three masks into one word: and at bit 0, or at bit 5, xor at
// bit 10.
3158 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
3160 SmallVector<Value> swizzled;
3161 for (Value v : decomposed) {
3163 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
3164 swizzled.emplace_back(res);
3167 Value
result = LLVM::composeValue(rewriter, loc, swizzled, src.
getType());
3168 rewriter.replaceOp(op,
result);
/// Conversion pattern that lowers PermlaneSwapOp to ROCDL::Permlane16SwapOp or
/// ROCDL::Permlane32SwapOp, depending on the op's row length.
///
/// The source value is decomposed into i32 pieces. Each piece is passed to the
/// swap op as both operands; the op returns a two-element struct, and one of
/// the two elements is selected as the permuted piece. The pieces are then
/// recomposed into the source type.
3173struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
3176 AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3177 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
3181 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
3182 ConversionPatternRewriter &rewriter)
const override {
3184 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
3186 Location loc = op.getLoc();
3187 Type i32 = rewriter.getI32Type();
3188 Value src = adaptor.getSrc();
3189 unsigned rowLength = op.getRowLength();
3190 bool fi = op.getFetchInactive();
3191 bool boundctrl = op.getBoundCtrl();
3193 SmallVector<Value> decomposed;
3194 if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3195 return rewriter.notifyMatchFailure(op,
3196 "failed to decompose value to i32");
3198 SmallVector<Value> permuted;
3199 for (Value v : decomposed) {
// The swap ops return a literal struct holding two values of v's type.
3201 Type i32pair = LLVM::LLVMStructType::getLiteral(
3202 rewriter.getContext(), {v.getType(), v.getType()});
3204 if (rowLength == 16)
3205 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3207 else if (rowLength == 32)
3208 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3211 llvm_unreachable(
"unsupported row length");
3213 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
3214 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
// Take vdst1 when vdst0 equals the input, otherwise take vdst0.
3216 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
3217 LLVM::ICmpPredicate::eq, vdst0, v);
3222 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
3223 permuted.emplace_back(vdstNew);
3226 Value
result = LLVM::composeValue(rewriter, loc, permuted, src.
getType());
3227 rewriter.replaceOp(op,
result);
// Bit layout constants used by the DsBarrier* lowerings below.
// Number of bits used for the pending-count field.
3240constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
// The phase field starts immediately after the pending-count field.
3241constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
// Bit position of the init-count field.
3242constexpr int32_t kDsBarrierInitCountPos = 32;
// Mask with the low kDsBarrierPendingCountBitWidth bits set.
3243constexpr int32_t kDsBarrierPendingCountMask =
3244 (1 << kDsBarrierPendingCountBitWidth) - 1;
/// Conversion pattern that lowers DsBarrierInitOp to an atomic 64-bit store
/// of the initial barrier state.
///
/// The state is built from a masked 32-bit count that is zero-extended to
/// i64, OR-ed with a left-shifted copy of itself, and stored with release
/// ordering at alignment 8. The original op is erased.
3246struct DsBarrierInitOpLowering
3247 :
public ConvertOpToLLVMPattern<DsBarrierInitOp> {
3250 DsBarrierInitOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3251 : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
3254 matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
3255 ConversionPatternRewriter &rewriter)
const override {
3257 return op->emitOpError(
"only supported on gfx1250+");
3259 Location loc = op.getLoc();
3260 Type i64 = rewriter.getI64Type();
3262 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3264 adaptor.getBase(), adaptor.getIndices());
// The stored count is derived from the participants operand by subtraction.
3271 LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
3278 Value maskedCount32 =
3279 LLVM::AndOp::create(rewriter, loc, initCount, countMask);
3280 Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
// The same masked count fills both the shifted field and the low field.
3282 Value initCountShifted = LLVM::ShlOp::create(
3283 rewriter, loc, maskedCount,
3285 Value barrierState =
3286 LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
3288 LLVM::StoreOp::create(
3289 rewriter, loc, barrierState, ptr, 8,
false,
3291 false, LLVM::AtomicOrdering::release,
3294 rewriter.eraseOp(op);
/// Conversion pattern that lowers DsBarrierPollStateOp to an atomic i64 load
/// with acquire ordering and alignment 8 from the address computed from the
/// op's base and indices.
3299struct DsBarrierPollStateOpLowering
3300 :
public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
3303 DsBarrierPollStateOpLowering(
const LLVMTypeConverter &converter,
3305 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
3309 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
3310 ConversionPatternRewriter &rewriter)
const override {
3312 return op->emitOpError(
"only supported on gfx1250+");
3314 Location loc = op.getLoc();
3315 Type i64 = rewriter.getI64Type();
3317 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3319 adaptor.getBase(), adaptor.getIndices());
3323 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
3324 op, i64, ptr, 8,
false,
3326 false, LLVM::AtomicOrdering::acquire,
/// Conversion pattern that replaces DsAsyncBarrierArriveOp with
/// ROCDL::DsAtomicAsyncBarrierArriveOp on the address computed from the op's
/// base and indices.
3332struct DsAsyncBarrierArriveOpLowering
3333 :
public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
3336 DsAsyncBarrierArriveOpLowering(
const LLVMTypeConverter &converter,
3338 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
3342 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
3343 ConversionPatternRewriter &rewriter)
const override {
3345 return op->emitOpError(
"only supported on gfx1250+");
3347 Location loc = op.getLoc();
3349 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3351 adaptor.getBase(), adaptor.getIndices());
3353 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
3354 op, ptr,
nullptr,
nullptr,
/// Conversion pattern that replaces DsBarrierArriveOp with
/// ROCDL::DsAtomicBarrierArriveRtnOp, which takes the computed address and the
/// op's count operand and returns an i64.
3360struct DsBarrierArriveOpLowering
3361 :
public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
3364 DsBarrierArriveOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3365 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
3369 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
3370 ConversionPatternRewriter &rewriter)
const override {
3372 return op->emitOpError(
"only supported on gfx1250+");
3374 Location loc = op.getLoc();
3375 Type i64 = rewriter.getI64Type();
3377 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3379 adaptor.getBase(), adaptor.getIndices());
3381 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3382 op, i64, ptr, adaptor.getCount(),
nullptr,
/// Conversion pattern that lowers DsBarrierStatePhaseOp: the state is
/// truncated to i32 (dropping its upper half) and logically shifted right to
/// yield the phase.
3388struct DsBarrierStatePhaseOpLowering
3389 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3393 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3394 ConversionPatternRewriter &rewriter)
const override {
3395 Location loc = op.getLoc();
3396 Type i32 = rewriter.getI32Type();
3398 Value state = adaptor.getState();
3400 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3401 Value phase = LLVM::LShrOp::create(
3402 rewriter, loc, noInitCount,
3405 rewriter.replaceOp(op, phase);
/// Conversion pattern that lowers DsBarrierStatePendingCountOp: the state is
/// truncated to i32 and AND-ed with kDsBarrierPendingCountMask to yield the
/// pending count.
3410struct DsBarrierStatePendingCountOpLowering
3411 :
public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3415 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3416 ConversionPatternRewriter &rewriter)
const override {
3417 Location loc = op.getLoc();
3418 Type i32 = rewriter.getI32Type();
3420 Value state = adaptor.getState();
3422 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3423 Value pendingCount = LLVM::AndOp::create(
3424 rewriter, loc, noInitCount,
3426 static_cast<uint32_t
>(kDsBarrierPendingCountMask)));
3428 rewriter.replaceOp(op, pendingCount);
/// Conversion pattern that lowers DsBarrierStateInitCountOp: the state is
/// logically shifted right and the shifted value is truncated to i32 to yield
/// the init count.
3433struct DsBarrierStateInitCountOpLowering
3434 :
public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3438 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3439 ConversionPatternRewriter &rewriter)
const override {
3440 Location loc = op.getLoc();
3441 Type i32 = rewriter.getI32Type();
3443 Value state = adaptor.getState();
3445 Value initCountI64 = LLVM::LShrOp::create(
3446 rewriter, loc, state,
3448 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3450 rewriter.replaceOp(op, initCount);
/// Conversion pattern that lowers DsBarrierStatePhaseParity: the state is
/// truncated to i32, logically shifted right to obtain the phase, and the
/// phase is truncated to i1 (its lowest bit) to yield the parity.
3455struct DsBarrierStatePhaseParityLowering
3456 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3460 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3461 ConversionPatternRewriter &rewriter)
const override {
3462 Location loc = op.getLoc();
3463 Type i1 = rewriter.getI1Type();
3465 Value state = adaptor.getState();
3468 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3469 Value phase = LLVM::LShrOp::create(
3470 rewriter, loc, noInitCount,
3472 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3474 rewriter.replaceOp(op, parity);
/// Shifts `value` left and ORs it into `accumulator`, returning the combined
/// value.
///
/// The OR is created with the disjoint flag set, so callers are expected to
/// pass an `accumulator` that has no bits in common with the shifted value.
/// NOTE(review): how `shift` is turned into `shiftAmount` (e.g. any reduction
/// to the register width) should be confirmed against the full function.
3483static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3484 Value accumulator, Value value, int64_t shift) {
3489 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3495 constexpr bool isDisjoint =
true;
3496 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
/// Conversion pattern, templated on the base op, that lowers a make-DMA-base
/// op to a 4 x i32 vector assembled from four 32-bit words (`sgprs`).
///
/// Word 1 holds the LDS pointer as an i32; word 2 holds the low 32 bits of the
/// global pointer; word 3 holds the masked high half of the global pointer
/// with an extra field written at offset 30. Word 0 starts from a constant
/// and, for gather ops, gets additional fields.
/// NOTE(review): presumably consts[i] is the i32 constant i — confirm.
3499template <
typename BaseOp>
3500struct AMDGPUMakeDmaBaseLowering :
public ConvertOpToLLVMPattern<BaseOp> {
3501 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3504 AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3505 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3509 matchAndRewrite(BaseOp op, Adaptor adaptor,
3510 ConversionPatternRewriter &rewriter)
const override {
3512 return op->emitOpError(
"make_dma_base is only supported on gfx1250");
3514 Location loc = op.getLoc();
3516 constexpr int32_t constlen = 4;
3517 Value consts[constlen];
3518 for (int64_t i = 0; i < constlen; ++i)
// One word per constant; all words start as consts[0].
3521 constexpr int32_t sgprslen = constlen;
3522 Value sgprs[sgprslen];
3523 for (int64_t i = 0; i < sgprslen; ++i) {
3524 sgprs[i] = consts[0];
3527 sgprs[0] = consts[1];
// Gather-only fields of word 0.
3529 if constexpr (BaseOp::isGather()) {
3530 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3532 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3533 Type indexType = type.getIndexType();
3535 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3536 "expected index_size to be 16 or 32");
// Maps an index size of 16 to 0 and of 32 to 1.
3537 unsigned idx = (indexSize / 16) - 1;
3540 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
3543 ValueRange ldsIndices = adaptor.getLdsIndices();
3544 Value lds = adaptor.getLds();
3545 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3548 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3550 ValueRange globalIndices = adaptor.getGlobalIndices();
3551 Value global = adaptor.getGlobal();
3552 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3555 rewriter, loc, globalMemRefType, global, globalIndices);
3557 Type i32 = rewriter.getI32Type();
3558 Type i64 = rewriter.getI64Type();
3560 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3561 Value castForGlobalAddr =
3562 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
// Low 32 bits of the global address.
3564 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
// High part of the global address: shift right, truncate to i32, then mask.
3566 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3569 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3572 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3574 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
3576 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3577 assert(v4i32 &&
"expected type conversion to succeed");
3578 Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
// Insert each word into the result vector; the paired constant is the
// insertion index.
3580 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3582 LLVM::InsertElementOp::create(rewriter, loc,
result, sgpr, constant);
3584 rewriter.replaceOp(op,
result);
/// Conversion pattern, templated on the descriptor op, whose helper methods
/// each write one field of the descriptor into a word (`sgpr`) via
/// setValueAtOffset.
3589template <
typename DescriptorOp>
3590struct AMDGPULowerDescriptor :
public ConvertOpToLLVMPattern<DescriptorOp> {
3591 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
// Stores the target chipset alongside the type converter.
3594 AMDGPULowerDescriptor(
const LLVMTypeConverter &converter, Chipset chipset)
3595 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
// Descriptor group 0 is the already-converted base operand, returned as-is.
3598 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
// Writes the workgroup mask into `sgpr0` at offset 0: the mask is bitcast to
// i16 and zero-extended to i32 before being combined.
3600 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3601 ConversionPatternRewriter &rewriter, Location loc,
3602 Value sgpr0)
const {
3603 Value mask = op.getWorkgroupMask();
3607 Type i16 = rewriter.getI16Type();
3608 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3609 Type i32 = rewriter.getI32Type();
3610 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3611 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
// Writes the data-size field into `sgpr0` at offset 16. The field is
// consts[log2(element width in bytes)]: widths 8/16/32/64 bits select
// consts[0..3].
// NOTE(review): presumably consts[k] is the i32 constant k — confirm.
3614 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3615 ConversionPatternRewriter &rewriter, Location loc,
3616 Value sgpr0, ArrayRef<Value> consts)
const {
3617 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3618 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3619 "expected type width to be 8, 16, 32, or 64.");
3620 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3621 Value size = consts[idx];
3622 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
// Writes consts[1] into `sgpr0` at offset 18; the preceding guard tests
// whether an atomic barrier address operand is present.
3625 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3626 ConversionPatternRewriter &rewriter, Location loc,
3627 Value sgpr0, ArrayRef<Value> consts)
const {
3628 if (!adaptor.getAtomicBarrierAddress())
3631 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
// Writes consts[1] into `sgpr0` at offset 19; the preceding guard tests
// whether a global increment operand is present.
3634 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3635 ConversionPatternRewriter &rewriter, Location loc,
3636 Value sgpr0, ArrayRef<Value> consts)
const {
3637 if (!adaptor.getGlobalIncrement())
3642 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
// Writes consts[1] into `sgpr0` at offset 20; the preceding guard tests
// whether the op has a pad amount.
3645 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3646 ConversionPatternRewriter &rewriter, Location loc,
3647 Value sgpr0, ArrayRef<Value> consts)
const {
3648 if (!op.getPadAmount())
3651 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
// Writes consts[1] into `sgpr0` at offset 21.
// NOTE(review): the guard tests the workgroup mask rather than a dedicated
// early-timeout operand — confirm this is the intended condition.
3654 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3655 ConversionPatternRewriter &rewriter, Location loc,
3656 Value sgpr0, ArrayRef<Value> consts)
const {
3657 if (!op.getWorkgroupMask())
3660 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
// Writes the pad-interval field into `sgpr0` at offset 22. The field is the
// count of trailing zeros of the pad interval operand, minus consts[1]. The
// preceding guard tests whether the op has a pad amount.
3663 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3664 ConversionPatternRewriter &rewriter, Location loc,
3665 Value sgpr0, ArrayRef<Value> consts)
const {
3666 if (!op.getPadAmount())
3675 IntegerType i32 = rewriter.getI32Type();
3676 Value padInterval = adaptor.getPadInterval();
3677 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3678 padInterval,
false);
3679 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3681 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
// Writes the pad-amount field into `sgpr0` at offset 25. The field is the pad
// amount operand minus consts[1]. The preceding guard tests whether the op
// has a pad amount.
3684 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3685 ConversionPatternRewriter &rewriter, Location loc,
3686 Value sgpr0, ArrayRef<Value> consts)
const {
3687 if (!op.getPadAmount())
3696 Value padAmount = adaptor.getPadAmount();
3697 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3699 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
// Writes the atomic barrier address field into `sgpr1` at offset 32. The
// address is converted to an i32, logically shifted right by consts[3] and
// masked before being combined. The preceding guard tests whether an atomic
// barrier address operand is present.
3702 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3703 ConversionPatternRewriter &rewriter,
3704 Location loc, Value sgpr1,
3705 ArrayRef<Value> consts)
const {
3706 if (!adaptor.getAtomicBarrierAddress())
3709 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3710 auto barrierAddressTy =
3711 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3712 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3714 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3715 atomicBarrierIndices);
3716 IntegerType i32 = rewriter.getI32Type();
3722 atomicBarrierAddress =
3723 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3724 atomicBarrierAddress =
3725 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3727 atomicBarrierAddress =
3728 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3729 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
// Writes the size of one global tensor dimension across two words and returns
// the updated pair. `dimX` counts from the last entry of the mixed global
// sizes. If there are not more than `dimX` sizes, both words are returned
// unchanged. The value goes into `sgpr1` at `offset`, and the value shifted
// right by c16 goes into `sgpr2` at `offset + 16`.
3732 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3733 ConversionPatternRewriter &rewriter,
3734 Location loc, Value sgpr1, Value sgpr2,
3735 ArrayRef<Value> consts, uint64_t dimX,
3736 uint32_t offset)
const {
3737 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3738 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3739 SmallVector<OpFoldResult> mixedGlobalSizes =
3741 if (mixedGlobalSizes.size() <= dimX)
3742 return {sgpr1, sgpr2};
3744 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
// Static sizes arrive as attributes; dynamic sizes are values truncated to
// i32.
3751 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3755 IntegerType i32 = rewriter.getI32Type();
3756 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3757 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3760 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3763 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3764 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3765 return {sgpr1, sgpr2};
// Forwards to setTensorDimX with dimX = 0 on the (sgpr1, sgpr2) pair.
3768 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3769 ConversionPatternRewriter &rewriter,
3770 Location loc, Value sgpr1, Value sgpr2,
3771 ArrayRef<Value> consts)
const {
3772 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
// Forwards to setTensorDimX with dimX = 1 on the (sgpr2, sgpr3) pair.
3776 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3777 ConversionPatternRewriter &rewriter,
3778 Location loc, Value sgpr2, Value sgpr3,
3779 ArrayRef<Value> consts)
const {
3780 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
// Writes the size of one shared (tile) dimension into `sgpr` at `offset`.
// `dimX` counts from the last entry of the mixed shared sizes; the preceding
// guard tests whether that many sizes exist.
3784 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3785 ConversionPatternRewriter &rewriter, Location loc,
3786 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
3787 int64_t offset)
const {
3788 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3789 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3790 SmallVector<OpFoldResult> mixedSharedSizes =
3792 if (mixedSharedSizes.size() <= dimX)
3795 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
// Static sizes arrive as attributes; dynamic sizes are values truncated to
// i32.
3804 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3808 IntegerType i32 = rewriter.getI32Type();
3809 tileDimX = cast<Value>(tileDimXOpFoldResult);
3810 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3813 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
// Forwards to setTileDimX with dimX = 0 and offset 112 on `sgpr3`.
3816 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3817 ConversionPatternRewriter &rewriter, Location loc,
3818 Value sgpr3, ArrayRef<Value> consts)
const {
3819 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
// Forwards to setTileDimX with dimX = 1 and offset 128 on `sgpr4`.
3822 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3823 ConversionPatternRewriter &rewriter, Location loc,
3824 Value sgpr4, ArrayRef<Value> consts)
const {
3825 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
// Writes a field derived from the length of the op's rank-1 indices vector
// into `sgpr4` at offset 128. The length is asserted to be in [1, 16].
3828 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3829 ConversionPatternRewriter &rewriter, Location loc,
3830 Value sgpr4, ArrayRef<Value> consts)
const {
3831 auto type = cast<VectorType>(op.getIndices().getType());
3832 ArrayRef<int64_t> shape = type.getShape();
3833 assert(shape.size() == 1 &&
"expected shape to be of rank 1.");
3834 unsigned length = shape.back();
3835 assert(0 < length && length <= 16 &&
"expected length to be at most 16.");
3837 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
/// Compile-time dispatch on DescriptorOp::isGather(): gather descriptors call
/// setValidIndices, all others call setTileDim1. Both write into `sgpr4`.
3840 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3841 ConversionPatternRewriter &rewriter,
3842 Location loc, Value sgpr4,
3843 ArrayRef<Value> consts)
const {
3844 if constexpr (DescriptorOp::isGather())
3845 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3846 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
/// Stores tile dimension index 2 into `sgpr4` at offset 144 via setTileDimX.
/// Gather descriptors take a separate compile-time branch.
/// NOTE(review): the body of the gather branch is not visible in this listing.
3849 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3850 ConversionPatternRewriter &rewriter, Location loc,
3851 Value sgpr4, ArrayRef<Value> consts)
const {
3853 if constexpr (DescriptorOp::isGather())
3855 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
/// Splits one global stride across two registers and returns the updated
/// {sgprY, sgprZ} pair.
///
/// The stride is taken from the adaptor's static/dynamic global strides at
/// position `dimX + 1` counted from the back. Its i32 truncation is stored in
/// `sgprY` at `offset`; the stride shifted right and truncated to i32 is stored
/// in `sgprZ`. When the list has at most `dimX + 1` entries, both registers are
/// returned unchanged.
/// NOTE(review): the attribute branch, the definitions of `mask` and
/// `shiftVal`, and the final offset argument for `sgprZ` are not visible here.
3858 std::pair<Value, Value>
3859 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3860 ConversionPatternRewriter &rewriter, Location loc,
3861 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
3862 size_t dimX, int64_t offset)
const {
3863 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
3864 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
3865 SmallVector<OpFoldResult> mixedGlobalStrides =
3866 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
// Nothing to write when the requested stride does not exist.
3868 if (mixedGlobalStrides.size() <= (dimX + 1))
3869 return {sgprY, sgprZ};
// Index from the back, skipping one more entry than `dimX`.
3871 OpFoldResult tensorDimXStrideOpFoldResult =
3872 *(mixedGlobalStrides.rbegin() + dimX + 1);
3877 Value tensorDimXStride;
3878 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3882 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
// (1 << 48) - 1: all ones in the low 48 bits.
3884 constexpr int64_t first48bits = (1ll << 48) - 1;
3887 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3888 IntegerType i32 = rewriter.getI32Type();
3889 Value tensorDimXStrideLow =
3890 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3891 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
// Shift amount is `offset % 32`, or 32 when `offset` is a multiple of 32.
3893 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3895 Value tensorDimXStrideHigh =
3896 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3897 tensorDimXStrideHigh =
3898 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3899 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3901 return {sgprY, sgprZ};
/// Forwards to setTensorDimXStride with the `sgpr5`/`sgpr6` pair and returns
/// that call's result.
/// NOTE(review): the dimension and offset arguments are not visible in this
/// listing.
3904 std::pair<Value, Value>
3905 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3906 ConversionPatternRewriter &rewriter, Location loc,
3907 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3908 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
/// For gather descriptors returns `sgpr5`/`sgpr6` unchanged; otherwise forwards
/// to setTensorDimXStride with that pair.
/// NOTE(review): the dimension and offset arguments are not visible in this
/// listing.
3912 std::pair<Value, Value>
3913 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3914 ConversionPatternRewriter &rewriter, Location loc,
3915 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3917 if constexpr (DescriptorOp::isGather())
3918 return {sgpr5, sgpr6};
3919 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
/// Assembles descriptor group 1 from eight register values.
///
/// Each of `sgprs[0..7]` starts as `consts[0]`, is updated by the field setters
/// below, and is then inserted into a poison-initialized vector of the
/// converted 8 x i32 type.
/// NOTE(review): the declaration of `sgprs`, the left-hand sides of two
/// assignments, and the return statement are not visible in this listing.
3923 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3924 ConversionPatternRewriter &rewriter, Location loc,
3925 ArrayRef<Value> consts)
const {
// Seed every register with consts[0].
3927 for (int64_t i = 0; i < 8; ++i) {
3928 sgprs[i] = consts[0];
// Register 0 accumulates eight fields, one setter call each.
3931 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3932 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3933 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3934 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3935 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3936 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3937 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3938 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3941 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
// Setters returning a pair update two neighbouring registers at once.
3942 std::tie(sgprs[1], sgprs[2]) =
3943 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3944 std::tie(sgprs[2], sgprs[3]) =
3945 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3947 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3949 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3950 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3951 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3952 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3953 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3954 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3956 IntegerType i32 = rewriter.getI32Type();
3957 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3958 assert(v8i32 &&
"expected type conversion to succeed");
3959 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
// zip_equal pairs each register with the constant passed as its insert
// position.
3961 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3963 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
/// Stores the global size selected by `dimX` into `sgpr0` at `offset` through
/// setValueAtOffset and returns the result.
///
/// The size list is built from the adaptor's static and dynamic global sizes.
/// NOTE(review): the list initializer, the early-exit body for
/// `size() <= dimX`, and the attribute branch are not visible in this listing.
3969 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3970 ConversionPatternRewriter &rewriter, Location loc,
3971 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3972 int64_t offset)
const {
3973 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3974 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3975 SmallVector<OpFoldResult> mixedGlobalSizes =
// `dimX` is signed here, so it is cast before the comparison with size().
3977 if (mixedGlobalSizes.size() <=
static_cast<unsigned long>(dimX))
// `dimX` is counted from the back of the list (reverse iterator + dimX).
3980 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3982 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
// The OpFoldResult is cast to a Value and truncated to i32 before storing.
3986 IntegerType i32 = rewriter.getI32Type();
3987 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3988 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3991 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
/// Stores tensor dimension index 2 into `sgpr0` at offset 0 via setTensorDimX.
3994 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3995 ConversionPatternRewriter &rewriter, Location loc,
3996 Value sgpr0, ArrayRef<Value> consts)
const {
3997 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
/// Truncates `value` to i32 and stores it into `accumulator` at `shift` via
/// setValueAtOffset, returning the result.
4000 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
4001 Location loc, Value accumulator,
4002 Value value, int64_t shift)
const {
4004 IntegerType i32 = rewriter.getI32Type();
4005 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
4006 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
/// Stores the adaptor's LDS increment operand into `sgpr1` at `offset` via
/// setValueAtOffset. No truncation is applied in this function.
4009 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4010 ConversionPatternRewriter &rewriter, Location loc,
4011 Value sgpr1, ArrayRef<Value> consts,
4012 int64_t offset)
const {
4013 Value ldsAddrIncrement = adaptor.getLdsIncrement();
4014 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
/// Splits the adaptor's global increment across two registers and returns the
/// updated {sgpr2, sgpr3} pair.
///
/// The increment truncated to i32 is stored in `sgpr2` at `offset`. The
/// increment shifted right by `shift` and truncated to i32 is stored in `sgpr3`
/// at `offset + 32`, after which `sgpr3` is ANDed with `mask`.
/// NOTE(review): the definitions of `shift` and `mask` are not visible in this
/// listing; `mask` presumably derives from `first16BitsHigh` — confirm.
4017 std::pair<Value, Value>
4018 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4019 ConversionPatternRewriter &rewriter, Location loc,
4020 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
4021 int64_t offset)
const {
4022 Value globalAddrIncrement = adaptor.getGlobalIncrement();
4023 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
4024 globalAddrIncrement, offset);
4026 globalAddrIncrement =
4027 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
// (1 << 16) - 1: all ones in the low 16 bits.
4028 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
4029 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
4030 globalAddrIncrement, offset + 32);
4032 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
4033 return {sgpr2, sgpr3};
/// Fills `sgpr1` using either setTensorDimX (dimension index 3) or
/// setLDSAddrIncrement; both candidate calls receive `sgpr1`.
/// NOTE(review): the condition choosing between the two calls is not visible in
/// this listing; presumably it tests `ldsIncrement` — confirm.
4036 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4037 ConversionPatternRewriter &rewriter,
4038 Location loc, Value sgpr1,
4039 ArrayRef<Value> consts)
const {
4040 Value ldsIncrement = op.getLdsIncrement();
4041 constexpr int64_t dim = 3;
4042 constexpr int64_t offset = 32;
4044 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
4046 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
/// Fills the `sgpr2`/`sgpr3` pair. When the op has no global increment, calls
/// setTensorDimXStride with dimension index 2 at offset 64; otherwise calls
/// setGlobalAddrIncrement.
/// NOTE(review): the trailing argument(s) of the setGlobalAddrIncrement call
/// are not visible in this listing.
4050 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
4051 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
4052 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts)
const {
4053 Value globalIncrement = op.getGlobalIncrement();
4054 constexpr int32_t dim = 2;
4055 constexpr int32_t offset = 64;
4056 if (!globalIncrement)
4057 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
4058 consts, dim, offset);
4059 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
/// Truncates the adaptor's iteration count to i32, builds a subtraction of
/// `consts[1]` from it, and stores `iterationCount` into `sgpr3` at `offset`.
/// NOTE(review): the left-hand side of the subtraction statement is not
/// visible in this listing; presumably the result is assigned back to
/// `iterationCount` and `consts[1]` is the constant 1 — confirm.
4063 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
4064 ConversionPatternRewriter &rewriter, Location loc,
4065 Value sgpr3, ArrayRef<Value> consts,
4066 int32_t offset)
const {
4067 Value iterationCount = adaptor.getIterationCount();
4068 IntegerType i32 = rewriter.getI32Type();
4075 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
4077 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
4078 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
/// Fills `sgpr3` at offset 112 using either setTileDimX or setIterateCount.
/// NOTE(review): the condition choosing between the two calls is not visible in
/// this listing; presumably it tests `iterateCount` — confirm.
/// NOTE(review): `dim` is 2 although the function name refers to tile dim 3,
/// and setTileDim2 already passes dimension index 2 — confirm this is intended
/// and not a copy-paste slip.
4081 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
4082 ConversionPatternRewriter &rewriter,
4083 Location loc, Value sgpr3,
4084 ArrayRef<Value> consts)
const {
4085 Value iterateCount = op.getIterationCount();
4086 constexpr int32_t dim = 2;
4087 constexpr int32_t offset = 112;
4089 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
4092 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
/// Compile-time dispatch on DescriptorOp::isGather(): gather descriptors use
/// getDGroup2Gather, all others use getDGroup2NonGather.
4095 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
4096 ConversionPatternRewriter &rewriter, Location loc,
4097 ArrayRef<Value> consts)
const {
4098 if constexpr (DescriptorOp::isGather())
4099 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
4100 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
/// Assembles descriptor group 2 for non-gather descriptors.
///
/// Returns an all-zero vector of the converted 4 x i32 type when the op has no
/// LDS increment and its rank is at most 2. Otherwise four registers seeded
/// with `consts[0]` are updated by the setters below and inserted into a
/// poison-initialized vector.
/// NOTE(review): some call arguments, one assignment's left-hand side, and the
/// return statement are not visible in this listing.
4103 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
4104 ConversionPatternRewriter &rewriter, Location loc,
4105 ArrayRef<Value> consts)
const {
4106 IntegerType i32 = rewriter.getI32Type();
4107 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4108 assert(v4i32 &&
"expected type conversion to succeed.");
// Shortcut: this group is emitted as zeros in the low-rank, no-increment case.
4110 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4111 if (onlyNeedsTwoDescriptors)
4112 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4114 constexpr int64_t sgprlen = 4;
4115 Value sgprs[sgprlen];
4116 for (
int i = 0; i < sgprlen; ++i)
4117 sgprs[i] = consts[0];
4119 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
4120 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
4122 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
4123 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4125 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
4127 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
// llvm::zip pairs each of the four registers with a constant from `consts`,
// which is passed as the insert position.
4128 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
4130 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
/// Builds one vector of the converted 4 x i32 type from a slice of the gather
/// index vector.
///
/// `firstHalf` selects the slice start: 0 when true, `maxLength` when false.
/// `maxLength` is 4 for i32 indices and 8 for any other element type. Non-i32
/// indices are zero-extended to i32 and then joined in pairs by
/// setValueAtOffset(first, second, 16).
/// NOTE(review): the definition of `idx` and the return statement are not
/// visible in this listing.
4135 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
4136 ConversionPatternRewriter &rewriter, Location loc,
4137 ArrayRef<Value> consts,
bool firstHalf)
const {
4138 IntegerType i32 = rewriter.getI32Type();
4139 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4140 assert(v4i32 &&
"expected type conversion to succeed.");
4142 Value
indices = adaptor.getIndices();
4143 auto vectorType = cast<VectorType>(
indices.getType());
4144 unsigned length = vectorType.getShape().back();
4145 Type elementType = vectorType.getElementType();
4146 unsigned maxLength = elementType == i32 ? 4 : 8;
4147 int32_t offset = firstHalf ? 0 : maxLength;
// Number of indices left once `offset` entries are skipped, floored at 0.
4148 unsigned discountedLength =
4149 std::max(
static_cast<int32_t
>(length - offset), 0);
// This slice holds at most `maxLength` indices.
4151 unsigned targetSize = std::min(maxLength, discountedLength);
// Extract the slice [offset, offset + targetSize) element by element.
4153 SmallVector<Value> indicesVector;
4154 for (
unsigned i = offset; i < targetSize + offset; ++i) {
4156 if (i < consts.size())
4160 Value elem = LLVM::ExtractElementOp::create(rewriter, loc,
indices, idx);
4161 indicesVector.push_back(elem);
// Bring every index to i32; i32 inputs are used as they are.
4164 SmallVector<Value> indicesI32Vector;
4165 if (elementType == i32) {
4166 indicesI32Vector = indicesVector;
4168 for (
unsigned i = 0; i < targetSize; ++i) {
4169 Value index = indicesVector[i];
4170 indicesI32Vector.push_back(
4171 LLVM::ZExtOp::create(rewriter, loc, i32, index));
// Pad an odd count with consts[0] so the entries can be paired below.
4173 if ((targetSize % 2) != 0)
4175 indicesI32Vector.push_back(consts[0]);
// For non-i32 inputs, join consecutive entries two at a time.
4178 SmallVector<Value> indicesToInsert;
4179 if (elementType == i32) {
4180 indicesToInsert = indicesI32Vector;
4182 unsigned size = indicesI32Vector.size() / 2;
4183 for (
unsigned i = 0; i < size; ++i) {
4184 Value first = indicesI32Vector[2 * i];
4185 Value second = indicesI32Vector[2 * i + 1];
4186 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
4187 indicesToInsert.push_back(joined);
4191 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
// zip_first iterates over `indicesToInsert`; each paired constant is passed
// as the insert position.
4192 for (
auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
4194 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
/// Descriptor group 2 for gather descriptors: getGatherIndices with
/// `firstHalf = true`.
4199 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
4200 ConversionPatternRewriter &rewriter, Location loc,
4201 ArrayRef<Value> consts)
const {
4202 return getGatherIndices(op, adaptor, rewriter, loc, consts,
true);
/// Forwards to setTensorDimXStride with the `sgpr0`/`sgpr1` pair. The locals
/// `dim` (3) and `offset` (0) are declared for that call.
/// NOTE(review): the trailing arguments of the call are not visible in this
/// listing; presumably they are `dim` and `offset` — confirm.
4205 std::pair<Value, Value>
4206 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
4207 ConversionPatternRewriter &rewriter, Location loc,
4208 Value sgpr0, Value sgpr1, ArrayRef<Value> consts)
const {
4209 constexpr int32_t dim = 3;
4210 constexpr int32_t offset = 0;
4211 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
/// Forwards to setTensorDimX with the `sgpr1`/`sgpr2` pair and dimension index
/// 4. The local `offset` (48) is declared for that call.
/// NOTE(review): the trailing argument of the call is not visible in this
/// listing; presumably it is `offset` — confirm.
4215 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
4216 ConversionPatternRewriter &rewriter,
4217 Location loc, Value sgpr1, Value sgpr2,
4218 ArrayRef<Value> consts)
const {
4219 constexpr int32_t dim = 4;
4220 constexpr int32_t offset = 48;
4221 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
/// Stores tile dimension index 4 into `sgpr2` at offset 80 via setTileDimX.
4225 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
4226 ConversionPatternRewriter &rewriter, Location loc,
4227 Value sgpr2, ArrayRef<Value> consts)
const {
4228 constexpr int32_t dim = 4;
4229 constexpr int32_t offset = 80;
4230 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
/// Compile-time dispatch on DescriptorOp::isGather(): gather descriptors use
/// getDGroup3Gather, all others use getDGroup3NonGather.
4233 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
4234 ConversionPatternRewriter &rewriter, Location loc,
4235 ArrayRef<Value> consts)
const {
4236 if constexpr (DescriptorOp::isGather())
4237 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
4238 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
/// Assembles descriptor group 3 for non-gather descriptors.
///
/// Returns an all-zero vector of the converted 4 x i32 type when the op has no
/// LDS increment and its rank is at most 2. Otherwise four registers seeded
/// with `consts[0]` are updated by the setters below and inserted into a
/// poison-initialized vector. No visible setter assigns `sgprs[3]`, so it keeps
/// its seed value in the lines shown.
/// NOTE(review): the return statement is not visible in this listing.
4241 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
4242 ConversionPatternRewriter &rewriter, Location loc,
4243 ArrayRef<Value> consts)
const {
4244 IntegerType i32 = rewriter.getI32Type();
4245 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4246 assert(v4i32 &&
"expected type conversion to succeed.");
// Shortcut: this group is emitted as zeros in the low-rank, no-increment case.
4247 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4248 if (onlyNeedsTwoDescriptors)
4249 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4251 constexpr int32_t sgprlen = 4;
4252 Value sgprs[sgprlen];
4253 for (
int i = 0; i < sgprlen; ++i)
4254 sgprs[i] = consts[0];
4256 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
4257 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
4258 std::tie(sgprs[1], sgprs[2]) =
4259 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
4260 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
4262 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
// llvm::zip pairs each of the four registers with a constant from `consts`,
// which is passed as the insert position.
4263 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
4265 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
/// Descriptor group 3 for gather descriptors: getGatherIndices with
/// `firstHalf = false`.
4270 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
4271 ConversionPatternRewriter &rewriter, Location loc,
4272 ArrayRef<Value> consts)
const {
4273 return getGatherIndices(op, adaptor, rewriter, loc, consts,
false);
/// Rewrites the descriptor op into four values (descriptor groups 0 through 3)
/// and replaces the op with them via replaceOpWithMultiple.
/// NOTE(review): the return type, the condition guarding the error, the body
/// of the `consts` loop, and the final return are not visible in this listing;
/// presumably the error is a chipset check and `consts` holds i32 constants
/// 0..7 — confirm.
4277 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
4278 ConversionPatternRewriter &rewriter)
const override {
4280 return op->emitOpError(
4281 "make_dma_descriptor is only supported on gfx1250");
4283 Location loc = op.getLoc();
// Eight shared values handed to every group builder as `consts`.
4285 SmallVector<Value> consts;
4286 for (int64_t i = 0; i < 8; ++i)
4289 Value dgroup0 = this->getDGroup0(adaptor);
4290 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
4291 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
4292 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
4293 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
// One source op is replaced by the group of four values.
4294 rewriter.replaceOpWithMultiple(op, {results});
/// Conversion pattern that replaces `SourceOp` with `TargetOp`.
///
/// The replacement receives `desc[0]`..`desc[3]` followed by an all-zero
/// vector of 8 x i32 as the fifth group operand.
/// NOTE(review): the definition of `desc`, the second constructor parameter,
/// the condition guarding the error, and the remaining operands of the
/// replacement are not visible in this listing.
4299template <
typename SourceOp,
typename TargetOp>
4300struct AMDGPUTensorLoadStoreOpLowering
4301 :
public ConvertOpToLLVMPattern<SourceOp> {
4302 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
4304 AMDGPUTensorLoadStoreOpLowering(
const LLVMTypeConverter &converter,
4306 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
4310 matchAndRewrite(SourceOp op, Adaptor adaptor,
4311 ConversionPatternRewriter &rewriter)
const override {
4313 return op->emitOpError(
"is only supported on gfx1250");
// The fifth descriptor group is always emitted as zeros here.
4318 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
4319 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
4320 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
4321 desc[3], dgroup4, 0,
/// Conversion pattern that replaces amdgpu GlobalPrefetchOp with
/// ROCDL::GlobalPrefetchOp.
///
/// The immediate argument is an i32 attribute built from `immArgValue`. The
/// pointer operand is `prefetchPtr`.
/// NOTE(review): the condition guarding the error, and the definitions of
/// `immArgValue`, `indices` and `prefetchPtr`, are not visible in this listing.
/// The argument lists shown suggest `immArgValue` comes from a call taking the
/// temporal hint, cache scope and speculative flag, and `prefetchPtr` from a
/// strided element pointer computation — confirm.
4329struct GlobalPrefetchOpLowering
4330 :
public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
4331 GlobalPrefetchOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
4332 : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}
4335 matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
4336 ConversionPatternRewriter &rewriter)
const override {
4338 return op->emitOpError(
"is only supported on gfx1250+");
4340 const bool isSpeculative = op.getSpeculative();
4342 op.getTemporalHint(), op.getCacheScope(), isSpeculative);
4343 IntegerAttr immArgAttr = rewriter.getI32IntegerAttr(immArgValue);
4346 Value memRef = adaptor.getSrc();
4347 MemRefDescriptor descriptor(memRef);
4348 MemRefType memRefType = op.getSrc().getType();
4349 Location loc = op->getLoc();
// Speculative prefetches use no GEP no-wrap flags; all others use
// inbounds | nuw.
4350 auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
4351 : LLVM::GEPNoWrapFlags::inbounds |
4352 LLVM::GEPNoWrapFlags::nuw;
4354 rewriter, loc, memRefType, descriptor,
indices, inboundsFlags);
4356 rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
4357 op, prefetchPtr, immArgAttr, mlir::ArrayAttr{}, mlir::ArrayAttr{},
/// Pass that converts the AMDGPU dialect using a partial conversion.
///
/// The AMDGPU dialect is marked illegal and the LLVM and ROCDL dialects legal.
/// The pass signals failure when `maybeChipset` is a failure value or when the
/// partial conversion fails.
/// NOTE(review): the definitions of `ctx`, `maybeChipset` and `target`, and the
/// call that fills `patterns`, are not visible in this listing.
4366struct ConvertAMDGPUToROCDLPass
4367 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
4370 void runOnOperation()
override {
// Report the offending chipset string and abort the pass.
4373 if (
failed(maybeChipset)) {
4374 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
4375 return signalPassFailure();
4378 RewritePatternSet patterns(ctx);
4379 LLVMTypeConverter converter(ctx);
4382 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
4384 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
4385 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
4386 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
4387 if (
failed(applyPartialConversion(getOperation(),
target,
4388 std::move(patterns))))
4389 signalPassFailure();
// Address-space mapping callback: each gpu::AddressSpace case shown (Global,
// Workgroup, Private, Constant) returns the matching ROCDL dialect
// address-space constant; control that gets past the cases hits
// llvm_unreachable.
// NOTE(review): the enclosing function header, the call this lambda is passed
// to, and the `switch` line are not visible in this listing.
4397 typeConverter, [](gpu::AddressSpace space) {
4399 case gpu::AddressSpace::Global:
4400 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
4401 case gpu::AddressSpace::Workgroup:
4402 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
4403 case gpu::AddressSpace::Private:
4404 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
4405 case gpu::AddressSpace::Constant:
4406 return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
4408 llvm_unreachable(
"unknown address space enum value");
// Yields an LLVM pointer type in the shared-memory address space.
// NOTE(review): the conversion this return belongs to is not visible here.
4411 return LLVM::LLVMPointerType::get(
4412 type.getContext(), ROCDL::ROCDLDialect::kSharedMemoryAddressSpace);
// Registers the AMDGPU-specific type and attribute conversions on
// `typeConverter`.
// NOTE(review): the enclosing function header, several lambda parameter lists,
// the definitions of `ctx`/`i32`, and the early-return bodies of the
// materialization callback are not visible in this listing.
//
// Address-space attributes become i64 integer attributes: FatRawBuffer -> 7,
// BufferRsrc -> 8, FatStructuredBuffer -> 9. Anything else aborts the
// conversion.
4418 typeConverter.addTypeAttributeConversion(
4420 -> TypeConverter::AttributeConversionResult {
4422 Type i64 = IntegerType::get(ctx, 64);
4423 switch (as.getValue()) {
4424 case amdgpu::AddressSpace::FatRawBuffer:
4425 return IntegerAttr::get(i64, 7);
4426 case amdgpu::AddressSpace::BufferRsrc:
4427 return IntegerAttr::get(i64, 8);
4428 case amdgpu::AddressSpace::FatStructuredBuffer:
4429 return IntegerAttr::get(i64, 9);
4431 return TypeConverter::AttributeConversionResult::abort();
// DsBarrierStateType converts to a 64-bit integer type.
4433 typeConverter.addConversion([&](DsBarrierStateType type) ->
Type {
4434 return IntegerType::get(type.
getContext(), 64);
// TDMBaseType and TDMGatherBaseType both convert to the converted form of
// VectorType::get(4, i32).
4436 typeConverter.addConversion([&](TDMBaseType type) ->
Type {
4438 return typeConverter.convertType(VectorType::get(4, i32));
4440 typeConverter.addConversion([&](TDMGatherBaseType type) ->
Type {
4442 return typeConverter.convertType(VectorType::get(4, i32));
// TDMDescriptorType expands to four result types, appended in the order
// v4i32, v8i32, v4i32, v4i32.
4444 typeConverter.addConversion(
4445 [&](TDMDescriptorType type,
4448 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4449 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4450 llvm::append_values(
result, v4i32, v8i32, v4i32, v4i32);
// Target materialization: checks for exactly one input of TDMDescriptorType,
// then bridges it to `types` with an UnrealizedConversionCastOp.
4460 if (inputs.size() != 1)
4463 if (!isa<TDMDescriptorType>(inputs[0].
getType()))
4466 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4467 return cast.getResults();
4470 typeConverter.addTargetMaterialization(addUnrealizedCast);
// Pattern registration. The first `add` constructs every listed pattern with
// (converter, chipset); the second constructs its patterns with (converter)
// only.
// NOTE(review): the enclosing function header and the object the first `.add`
// is chained on are not visible in this listing.
4478 .
add<FatRawBufferCastLowering,
4479 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4480 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4481 RawBufferOpLowering<RawBufferAtomicFaddOp,
4482 ROCDL::RawPtrBufferAtomicFaddOp>,
4483 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4484 ROCDL::RawPtrBufferAtomicFmaxOp>,
4485 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4486 ROCDL::RawPtrBufferAtomicSmaxOp>,
4487 RawBufferOpLowering<RawBufferAtomicUminOp,
4488 ROCDL::RawPtrBufferAtomicUminOp>,
4489 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4490 ROCDL::RawPtrBufferAtomicCmpSwap>,
4491 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4492 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4493 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4494 SparseWMMAOpLowering, DotOpLowering, ExtPackedFp8OpLowering,
4495 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4496 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4497 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4498 GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
4499 GlobalTransposeLoadOpLowering, AMDGPUPermlaneLowering,
4500 AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4501 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4502 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4503 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4504 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4505 ROCDL::TensorLoadToLDSOp>,
4506 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4507 ROCDL::TensorStoreFromLDSOp>,
4508 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4509 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
4510 GlobalPrefetchOpLowering>(converter, chipset);
// These patterns are constructed without a chipset argument.
4511 patterns.
add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4512 DsBarrierStatePendingCountOpLowering,
4513 DsBarrierStateInitCountOpLowering,
4514 DsBarrierStatePhaseParityLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static bool hasDot10Insts(const Chipset &chipset)
static bool hasDot7Insts(const Chipset &chipset)
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static bool hasDot11Insts(const Chipset &chipset)
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static bool hasDot9Insts(const Chipset &chipset)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
constexpr Chipset kGfx1200
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts packed vector operands to the expected ROCDL types.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static bool hasDot8Insts(const Chipset &chipset)
static bool hasDot2Insts(const Chipset &chipset)
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static bool hasDot12Insts(const Chipset &chipset)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static bool hasDot1Insts(const Chipset &chipset)
if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
typename SourceOp::Adaptor OpAdaptor
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
int32_t getGlobalPrefetchLLVMEncoding(amdgpu::LoadTemporalHint hint, amdgpu::Scope scope, bool isSpeculative)
bool hasOcpFp8(const Chipset &chipset)
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.