#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
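// Helpers for widening unsigned values: `val` is truncated or zero-extended
// to i32 (and, below, to i64), and small wrappers create i32/i64 constants.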
  IntegerType i32 = rewriter.getI32Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));

  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
  IntegerType i64 = rewriter.getI64Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));

  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
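// Compute a linearized index in i32 arithmetic: each index is multiplied by
// its stride (loaded from the memref descriptor when the stride is dynamic)
// and the scaled increments are accumulated.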
  IntegerType i32 = rewriter.getI32Type();
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    Value strideValue =
        ShapedType::isDynamic(stride)
            ? convertUnsignedToI32(rewriter, loc,
                                   memRefDescriptor.stride(rewriter, loc, i))
            : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
    increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
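// Compute the buffer's num_records field, i.e. the addressable extent in
// bytes: a compile-time constant when the shape and strides are static,
// otherwise max(size[i] * stride[i]) scaled by the element byte width,
// computed at runtime. On gfx1250 with bounds checking disabled, the full
// 45-bit range is used instead.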
                           MemRefType memrefType,

  if (chipset >= kGfx1250 && !boundsCheck) {
    constexpr int64_t first45bits = (1ll << 45) - 1;
    return createI64Constant(rewriter, loc, first45bits);
  }

  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;

  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
  Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
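// Build a buffer resource descriptor (V#) via rocdl.make.buffer.rsrc. A
// provided cache-swizzle stride is zero-extended to i16 and OR'd with bit 14
// (the swizzle-enable bit) to form the descriptor's stride field; the flags
// word sets the (must-be-nonzero) data-format bits via (7 << 12) | (4 << 15)
// and places the out-of-bounds-check mode in bits 29:28 (3 when bounds
// checking is requested, 2 otherwise).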
                              Value cacheSwizzleStride = nullptr,
                              unsigned addressSpace = 8) {
  Type i16 = rewriter.getI16Type();

    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));

  flags |= (7 << 12) | (4 << 15);

  uint32_t oob = boundsCheck ? 3 : 2;
  flags |= (oob << 28);

      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
struct FatRawBufferCastLowering
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),

  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
        getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
                      elementByteWidth, chipset, adaptor.getBoundsCheck());

        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;

    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,

        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        getTypeConverter()->convertType(op.getResult().getType()));
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
    rewriter.replaceOp(op, result);
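// Shared lowering for the raw buffer ops (loads, stores, and atomics): build
// a buffer resource from the memref, compute the 32-bit voffset/soffset
// operands in bytes, and bitcast payloads to the integer types the buffer
// intrinsics accept (small element types are packed into whole 32-bit words).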
template <typename GpuOp, typename Intrinsic>
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  static constexpr uint32_t maxVectorOpWidth = 128;

  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref)
      wantedDataType = storeData.getType();
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
    Type llvmBufferValType = llvmWantedDataType;
    if (auto floatType = dyn_cast<FloatType>(wantedDataType))
      llvmBufferValType = this->getTypeConverter()->convertType(
          rewriter.getIntegerType(floatType.getWidth()));
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t totalBits = elemBits * vecLen;
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();

    SmallVector<Value, 6> args;

      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
        args.push_back(storeData);

      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
        args.push_back(atomicCmpData);

    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
        getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
                      elementByteWidth, chipset, adaptor.getBoundsCheck());
        adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

        adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,

    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (llvmBufferValType != llvmWantedDataType) {
      replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
      rewriter.eraseOp(gpuOp);
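// Pack (vmcnt, expcnt, lgkmcnt) into an s_waitcnt immediate. The field widths
// and positions differ per generation, hence the different clamps and shifts
// below: a 4-bit vmcnt on the oldest targets, a 6-bit vmcnt split across bits
// 3:0 and 15:14 on newer ones, and the gfx11 layout with vmcnt at bits 15:10.
// Returns failure for chipsets without a known encoding.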
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
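// Lower MemoryCounterWaitOp: gfx12+ exposes one dedicated ROCDL wait op per
// counter (dscnt, loadcnt, storecnt, expcnt, tensorcnt); older chipsets fold
// everything into a single encoded s_waitcnt immediate, with loads and stores
// sharing vmcnt.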
struct MemoryCounterWaitOpLowering
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),

  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      if (std::optional<int> tensor = adaptor.getTensor())
        ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);

      rewriter.eraseOp(op);

    if (adaptor.getTensor())
      return op.emitOpError("unsupported chipset");

    auto getVal = [](Attribute attr) -> unsigned {
      return cast<IntegerAttr>(attr).getInt();

    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
      vmcnt = getVal(load) + getVal(store);
      vmcnt = getVal(load);
      vmcnt = getVal(store);

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
      return op.emitOpError("unsupported chipset");
    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    bool requiresInlineAsm = chipset < kGfx90a;

        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");

    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
                                     bool allowBf16 = true) {
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);

                                      bool allowBf16 = true) {
  auto vectorType = cast<VectorType>(inputType);

  if (vectorType.getElementType().isBF16() && !allowBf16)
    return LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);

  if (isa<IntegerType>(vectorType.getElementType()) &&
      vectorType.getElementTypeBitWidth() <= 8) {
    int64_t numWords = llvm::divideCeil(
        vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
    return LLVM::BitcastOp::create(
        rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
      .Case([&](IntegerType) {
        return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
      .Case([&](VectorType vectorType) {
        int64_t numElements = vectorType.getNumElements();
        assert((numElements == 4 || numElements == 8) &&
               "scale operand must be a vector of length 4 or 8");
        IntegerType outputType =
            (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
        return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
      })
      .DefaultUnreachable("unexpected input type for scale operand");
      .Case([](Float8E8M0FNUType) { return 0; })
      .Case([](Float8E4M3FNType) { return 2; })
      .Default(std::nullopt);
static std::optional<StringRef>
  if (m == 16 && n == 16 && k == 128)
    return isScale16
               ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
               : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();

  if (m == 32 && n == 16 && k == 128)
    return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
                     : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
                               ConversionPatternRewriter &rewriter, Location loc,

  auto vectorType = dyn_cast<VectorType>(inputType);
    operands.push_back(llvmInput);

  Type elemType = vectorType.getElementType();
    operands.push_back(llvmInput);

  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    bool localIsUnsigned = isUnsigned;
      localIsUnsigned = true;
      localIsUnsigned = false;
        NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));

  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
                          Value output, int32_t subwordOffset,

  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  operands.push_back(output);

  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));

  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();

  if (mfma.getReducePrecision() && chipset >= kGfx942) {
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
  }

  if (m == 32 && n == 32 && k == 1 && b == 2)
    return ROCDL::mfma_f32_32x32x1f32::getOperationName();
  if (m == 16 && n == 16 && k == 1 && b == 4)
    return ROCDL::mfma_f32_16x16x1f32::getOperationName();
  if (m == 4 && n == 4 && k == 1 && b == 16)
    return ROCDL::mfma_f32_4x4x1f32::getOperationName();
  if (m == 32 && n == 32 && k == 2 && b == 1)
    return ROCDL::mfma_f32_32x32x2f32::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 1)
    return ROCDL::mfma_f32_16x16x4f32::getOperationName();

  if (m == 32 && n == 32 && k == 16 && b == 1)
    return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1)
    return ROCDL::mfma_f32_16x16x32_f16::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_f32_32x32x4f16::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_f32_16x16x4f16::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_f32_4x4x4f16::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_f32_32x32x8f16::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_f32_16x16x16f16::getOperationName();

  if (m == 32 && n == 32 && k == 16 && b == 1)
    return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1)
    return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();

  if (m == 32 && n == 32 && k == 2 && b == 2)
    return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
  if (m == 16 && n == 16 && k == 2 && b == 4)
    return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
  if (m == 4 && n == 4 && k == 2 && b == 16)
    return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
  if (m == 32 && n == 32 && k == 4 && b == 1)
    return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
  if (m == 16 && n == 16 && k == 8 && b == 1)
    return ROCDL::mfma_f32_16x16x8bf16::getOperationName();

  if (m == 32 && n == 32 && k == 32 && b == 1)
    return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
  if (m == 16 && n == 16 && k == 64 && b == 1)
    return ROCDL::mfma_i32_16x16x64_i8::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_i32_32x32x4i8::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_i32_16x16x4i8::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_i32_4x4x4i8::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_i32_32x32x8i8::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_i32_16x16x16i8::getOperationName();
  if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
    return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
    return ROCDL::mfma_i32_16x16x32_i8::getOperationName();

  if (m == 16 && n == 16 && k == 4 && b == 1)
    return ROCDL::mfma_f64_16x16x4f64::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 4)
    return ROCDL::mfma_f64_4x4x4f64::getOperationName();

        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1) {
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();

        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1) {
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();

  return std::nullopt;
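// Map an operand element type to the cbsz/blgp type code used by the scaled
// f8f6f4 MFMA intrinsics: fp8 (E4M3) -> 0, bf8 (E5M2) -> 1, fp6 (E2M3) -> 2,
// bf6 (E3M2) -> 3, fp4 (E2M1) -> 4.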
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
                       uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
      smfma.getSourceB().getType(),
      smfma.getDestC().getType(), smfma.getM(),
      smfma.getN(), smfma.getK(), 1u, chipset);
static std::optional<StringRef>
                    Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;

  if (k == 32 && !isRDNA3) {
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();

  return std::nullopt;
                      Type elemBSourceType,
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
    return std::nullopt;

    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();

    return std::nullopt;

  return std::nullopt;
  bool isGfx950 = chipset >= kGfx950;

  uint32_t m = op.getM(), n = op.getN(), k = op.getK();

  if (m == 16 && n == 16 && k == 32) {
      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();

  if (m == 16 && n == 16 && k == 64) {
      return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
      return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();

  if (m == 16 && n == 16 && k == 128 && isGfx950) {
      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();

  if (m == 32 && n == 32 && k == 16) {
      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();

  if (m == 32 && n == 32 && k == 32) {
      return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
      return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();

  if (m == 32 && n == 32 && k == 64 && isGfx950) {
      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();

  return std::nullopt;
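// Select the ROCDL intrinsic for a WMMAOp from the source/dest element types,
// the K dimension, and the RDNA generation of the target chipset.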
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  const uint32_t k = wmma.getK();

  if (isRDNA3 || isRDNA4)

  return std::nullopt;
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
        return op.emitOpError("negation unsupported on older than gfx942");
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);

    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;

    bool allowBf16 = [&]() {
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");

    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
            rewriter, loc, adaptor.getSourceA(), allowBf16),
            rewriter, loc, adaptor.getSourceB(), allowBf16),
        adaptor.getDestC()});

      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({zero, zero});
      loweredOp.addAttributes({{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
                               {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
                               {"opselA", rewriter.getI32IntegerAttr(0)},
                               {"opselB", rewriter.getI32IntegerAttr(0)}});
      loweredOp.addAttributes(
          {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
           {"abid", rewriter.getI32IntegerAttr(op.getAbid())},
           {"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        adaptor.getDestC()});
    loweredOp.addOperands(
    loweredOp.addAttributes(
        {{"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
         {"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
         {"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
         {"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
  SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestC().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 9 || chipset < kGfx942)
      return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
    bool isGfx950 = chipset >= kGfx950;

        adaptor.getSourceA(), isGfx950);
        adaptor.getSourceB(), isGfx950);
    Value c = adaptor.getDestC();

    if (!maybeIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching sparse MFMA on the given chipset");

    Value sparseIdx = LLVM::BitcastOp::create(
        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());

    OperationState loweredOp(loc, maybeIntrinsic.value());
    loweredOp.addTypes(outType);
    loweredOp.addOperands({a, b, c, sparseIdx});
    loweredOp.addAttributes(
        {{"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
         {"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestD().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    VectorType rawOutType = outType;
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
        op.getSourceA(), operands, attrs, "signA");
        op.getSourceB(), operands, attrs, "signB");
        op.getSubwordOffset(), op.getClamp(), operands,

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
    rewriter.replaceOp(op, maybeCastBack->getResults());
  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestD().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

      return op->emitOpError("WMMA scale only supported on gfx1250+");

    int64_t m = op.getM();
    int64_t n = op.getN();
    int64_t k = op.getK();

    if (!aFmtCode || !bFmtCode)
      return op.emitOpError("unsupported element types for scaled_wmma");

    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());

    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
      return op.emitOpError("scaleA and scaleB must have equal vector length");

    Type scaleAElemType = scaleAVecType.getElementType();
    Type scaleBElemType = scaleBVecType.getElementType();

    if (!scaleAFmt || !scaleBFmt)
      return op.emitOpError("unsupported scale element types");

    bool isScale16 = (scaleAVecType.getNumElements() == 8);
    std::optional<StringRef> intrinsicName =
      return op.emitOpError("unsupported scaled_wmma dimensions: ")
             << m << "x" << n << "x" << k;

    SmallVector<NamedAttribute, 8> attrs;

    bool is32x16 = (m == 32 && n == 16 && k == 128);

    attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
    attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));

    attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));

        "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
        "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));

    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));

    OperationState loweredOp(loc, *intrinsicName);
    loweredOp.addTypes(outType);
    loweredOp.addOperands(
        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
    loweredOp.addAttributes(attrs);

    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());
struct TransposeLoadOpLowering
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")

    auto resultType = cast<VectorType>(op.getResult().getType());
        (adaptor.getSrcIndices()));

    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =
        resultType.getElementType().getIntOrFloatBitWidth();

    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);

    switch (elementTypeSize) {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);

      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);

      assert(numElements == 8);
      auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);

      assert(numElements == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,

      return op.emitOpError("Unsupported element size for transpose load");
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    Type transferType = op.getTransferType();
    int loadWidth = [&]() -> int {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return (transferVectorType.getNumElements() *
                transferVectorType.getElementTypeBitWidth()) /

    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");

    if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

        (adaptor.getSrcIndices()));
        (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        rewriter.getI32IntegerAttr(0),
struct ExtPackedFp8OpLowering final
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),

  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct ScaledExtPackedMatrixOpLowering final
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),

  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedTrunc2xFp8OpLowering final
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),

  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedStochRoundFp8OpLowering final
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),

  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct ScaledExtPackedOpLowering final
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),

  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedScaledTruncOpLowering final
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),

  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());

  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);

  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  if (resultVecType) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
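// Encode the scaleSel immediate for the packed scaled-conversion intrinsics.
// Bit 0 marks 16-element blocks. For 8-bit (fp8/bf8) sources, bit 1 is set
// when the scale starts at byte 2 and bit 2 selects the upper half-wave; for
// 4- and 6-bit sources, firstScaleByte occupies bits 2:1 and bit 3 selects
// the half-wave, with certain byte/half combinations asserted invalid.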
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
                    int32_t firstScaleByte) {
  assert(llvm::is_contained({16, 32}, blockSize));
  assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));

  const bool isFp8 = bitWidth == 8;
  const bool isBlock16 = blockSize == 16;

    int32_t bit0 = isBlock16;
    assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
    int32_t bit1 = (firstScaleByte == 2) << 1;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    int32_t bit2 = scaleWaveHalf << 2;
    return bit2 | bit1 | bit0;

    int32_t bit0 = isBlock16;
    assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
    int32_t bits2and1 = firstScaleByte << 1;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    int32_t bit3 = scaleWaveHalf << 3;
    int32_t bits = bit3 | bits2and1 | bit0;

    assert(!llvm::is_contained(
        {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
static std::optional<StringRef>
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  if (isa<fp4>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
    return std::nullopt;
  }
  llvm_unreachable("invalid combination of element types for packed conversion "
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
    ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        "Scaled fp packed conversion instructions are not available on target "
        "architecture and their emulation is not implemented");

  int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
  int32_t firstScaleByte = op.getFirstScaleByte();
  int32_t blockSize = op.getBlockSize();
  auto sourceType = cast<VectorType>(op.getSource().getType());
  auto srcElemType = cast<FloatType>(sourceType.getElementType());
  unsigned bitWidth = srcElemType.getWidth();

  auto targetType = cast<VectorType>(op.getResult().getType());
  auto destElemType = cast<FloatType>(targetType.getElementType());

  IntegerType i32 = rewriter.getI32Type();
  Value source = adaptor.getSource();
  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
  Type packedType = nullptr;
  if (isa<fp4>(srcElemType)) {
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp8, bf8>(srcElemType)) {
    packedType = VectorType::get(2, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp6, bf6>(srcElemType)) {
    packedType = VectorType::get(3, i32);
    packedType = getTypeConverter()->convertType(packedType);
    llvm_unreachable("invalid element type for packed scaled ext");

  if (!packedType || !llvmResultType) {
    return rewriter.notifyMatchFailure(op, "type conversion failed");

  std::optional<StringRef> maybeIntrinsic =
      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
  if (!maybeIntrinsic.has_value())
    return op.emitOpError(
        "no intrinsic matching packed scaled conversion on the given chipset");

      getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
  Value castedSource =
      LLVM::BitcastOp::create(rewriter, loc, packedType, source);

  OperationState loweredOp(loc, *maybeIntrinsic);
  loweredOp.addTypes({llvmResultType});
  loweredOp.addOperands({castedSource, castedScale});

  SmallVector<NamedAttribute, 1> attrs;
      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));

  loweredOp.addAttributes(attrs);
  Operation *lowered = rewriter.create(loweredOp);
  rewriter.replaceOp(op, lowered);
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  VectorType packedVecType;
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
  } else {
    llvm_unreachable("invalid element type for scaled ext");
  }

  // Widen the source to the full packed shape if it has fewer elements.
  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else
    return failure();
  return success();
}
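
// Illustrative note: a source shorter than the packed shape, e.g. a
// vector<2xf8E4M3FN>, is first widened to the 4-element packed form with the
// upper lanes zeroed, then bitcast to i32 so the CvtScaleF32Pk* intrinsics
// can consume it; op.getIndex() selects which packed word the hardware reads.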
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  Value result;
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
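
// Illustrative note: fp4 results accumulate into an i32 (eight 4-bit
// results), while fp8/bf8 results accumulate into a vector<2xi16>; threading
// `existing` through the intrinsic preserves halves written by earlier
// truncations before the final bitcast back to the MLIR result type.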
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
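
// Illustrative note: the stochastic-rounding converts use the extra `stoch`
// operand as the randomness source, and op.getStoreIndex() picks the
// destination byte within the 32-bit result word, mirroring the word-index
// mechanism of the packed truncation above.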
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // Operands narrower than 32 bits are widened to i32 by inserting them
    // into a vector and bitcasting.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand =
            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
                                          createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // DPP control codes; values match llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {
    case DPPPerm::quad_perm: {
      auto quadPermAttr = cast<ArrayAttr>(*permArgument);
      int32_t i = 0;
      for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
        uint32_t num = elem.getInt();
        DppCtrl |= num << (i * 2);
        ++i;
      }
      break;
    }
    case DPPPerm::row_shl: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      break;
    }
    case DPPPerm::row_shr: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      break;
    }
    case DPPPerm::row_ror: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      break;
    }
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Read the row_mask, bank_mask, and bound_ctrl attributes.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL DPP update op with the computed control word.
    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
                                   rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    rewriter.replaceOp(DppOp, result);
    return success();
  }
};
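
// Worked example (illustrative): quad_perm [1, 0, 3, 2] packs each 2-bit
// selector into the control word:
//   DppCtrl = 1 << 0 | 0 << 2 | 3 << 4 | 2 << 6 = 0xB1
// which stays below ROW_SHL0 (0x100), so quad_perm codes never collide with
// the row/wave operation codes above.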
struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();

    // Pack the three masks into the ds_swizzle "bit mode" encoding: bits
    // [4:0] AND, [9:5] OR, [14:10] XOR; bit 15 stays 0 for bit mode.
    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
    Value maskValue = createI32Constant(rewriter, loc, mask);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }

    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
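
// Worked example (illustrative): swapping adjacent lanes uses
// and_mask = 0x1F, or_mask = 0, xor_mask = 1, giving
//   mask = 0x1F | (0 << 5) | (1 << 10) = 0x41F
// so each lane reads from ((lane & 0x1F) | 0) ^ 1, i.e. its neighbor.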
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});
      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        llvm_unreachable("unsupported row length");

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
                                           LLVM::ICmpPredicate::eq, vdst0, v);

      // The intrinsic returns both halves of the swap; keep the half that
      // differs from the input value.
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
                              Value accumulator, Value value, int64_t shift) {
  // Callers pass descriptor-wide bit positions; only the offset within the
  // current 32-bit word matters here. (The modulo reduction is assumed; the
  // original lines were lost in extraction.)
  shift %= 32;
  if (shift != 0) {
    Value shiftAmount = createI32Constant(rewriter, loc, shift);
    value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
  }
  // The target bits of the accumulator are known to be zero, so the OR can
  // be marked disjoint.
  constexpr bool isDisjoint = true;
  return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
}
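
// Worked example (illustrative): setValueAtOffset(rewriter, loc, sgpr0,
// consts[1], 18) ORs 1 << 18 into sgpr0. Because those accumulator bits are
// known zero, the disjoint flag lets LLVM treat the OR as a carry-free
// bitfield insert.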
template <typename BaseOp>
struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
  using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
  using Adaptor = typename BaseOp::Adaptor;
  Chipset chipset;

  AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(BaseOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError("make_dma_base is only supported on gfx1250");

    Location loc = op.getLoc();

    constexpr int32_t constlen = 4;
    Value consts[constlen];
    for (int64_t i = 0; i < constlen; ++i)
      consts[i] = createI32Constant(rewriter, loc, i);

    constexpr int32_t sgprslen = constlen;
    Value sgprs[sgprslen];
    for (int64_t i = 0; i < sgprslen; ++i) {
      sgprs[i] = consts[0];
    }

    sgprs[0] = consts[1];

    if constexpr (BaseOp::isGather()) {
      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);

      auto type = cast<TDMGatherBaseType>(op.getResult().getType());
      Type indexType = type.getIndexType();
      unsigned indexSize = indexType.getIntOrFloatBitWidth();
      assert(llvm::is_contained({16u, 32u}, indexSize) &&
             "expected index_size to be 16 or 32");
      unsigned idx = (indexSize / 16) - 1;
      // Bit 31 flags 32-bit gather indices. (The guard is assumed; the
      // original line was lost in extraction.)
      if (idx)
        sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
    }

    ValueRange ldsIndices = adaptor.getLdsIndices();
    Value lds = adaptor.getLds();
    auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
    Value ldsPtr = this->getStridedElementPtr(
        rewriter, loc, ldsMemRefType, lds, ldsIndices);

    ValueRange globalIndices = adaptor.getGlobalIndices();
    Value global = adaptor.getGlobal();
    auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
    Value globalPtr = this->getStridedElementPtr(
        rewriter, loc, globalMemRefType, global, globalIndices);

    Type i32 = rewriter.getI32Type();
    Type i64 = rewriter.getI64Type();

    sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
    Value castForGlobalAddr =
        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);

    sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);

    Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
                                       createI64Constant(rewriter, loc, 32));
    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
    // Keep only the valid address bits above bit 32. A 25-bit mask (57-bit
    // virtual addresses) is assumed here; the original constant was lost in
    // extraction.
    Value mask = createI32Constant(rewriter, loc, (1 << 25) - 1);
    highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);

    sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);

    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed");
    Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);

    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
      result =
          LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);

    rewriter.replaceOp(op, result);
    return success();
  }
};
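
// Resulting layout (as far as it can be read off the code above): word 0
// carries the flag bits, word 1 the LDS address, word 2 the low 32 bits of
// the global address, and word 3 the masked high bits of the global address
// with a 2-bit tag (consts[2]) at bits 30-31.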
template <typename DescriptorOp>
struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
  using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
  using OpAdaptor = typename DescriptorOp::Adaptor;
  Chipset chipset;

  AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}

  Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }

  Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0) const {
    Value mask = op.getWorkgroupMask();
    if (!mask)
      return sgpr0;

    Type i16 = rewriter.getI16Type();
    mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
    Type i32 = rewriter.getI32Type();
    Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
    return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
  }
  Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr0, ArrayRef<Value> consts) const {
    unsigned elementTypeWidthInBits = op.getElementTypeWidth();
    assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
           "expected type width to be 8, 16, 32, or 64.");
    int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
    Value size = consts[idx];
    return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
  }
  Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
  }

  Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getGlobalIncrement())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
  }

  Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
  }

  Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getWorkgroupMask())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
  }
  Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // The pad interval is a power of two; counting trailing zeros yields
    // log2(interval), and the encoded value is biased by one.
    IntegerType i32 = rewriter.getI32Type();
    Value padInterval = adaptor.getPadInterval();
    padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
                                                     padInterval,
                                                     /*is_zero_poison=*/false);
    padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
  }

  Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    // The encoded pad amount is biased by one.
    Value padAmount = adaptor.getPadAmount();
    padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
  }
  Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr1,
                                ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr1;

    Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
    auto barrierAddressTy =
        cast<MemRefType>(op.getAtomicBarrierAddress().getType());
    ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
    atomicBarrierAddress = this->getStridedElementPtr(
        rewriter, loc, barrierAddressTy, atomicBarrierAddress,
        atomicBarrierIndices);
    IntegerType i32 = rewriter.getI32Type();
    // Shift right by consts[3] (= 3) to drop the low address bits, then keep
    // the field's width. The mask constant is assumed (lost in extraction);
    // 16 bits are used here.
    atomicBarrierAddress =
        LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
    atomicBarrierAddress =
        LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
    Value mask = createI32Constant(rewriter, loc, (1 << 16) - 1);
    atomicBarrierAddress =
        LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
    return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
  }
  std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts, uint64_t dimX,
                                        uint32_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= dimX)
      return {sgpr1, sgpr2};

    // Sizes are stored outermost-first; dimX counts from the innermost dim.
    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);

    // The 32-bit dimension straddles a word boundary: the high half lands 16
    // bits further, in the next sgpr.
    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
    return {sgpr1, sgpr2};
  }
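
// Worked example (illustrative): with offset = 48, a 32-bit dimension lands
// with its low 16 bits in bits 16-31 of sgpr1 (48 mod 32 = 16) and its high
// 16 bits at offset 64, i.e. bits 0-15 of sgpr2.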
  std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    // The offset literal was lost in extraction; 48 is assumed, consistent
    // with the field layout implied by the other setters.
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
                         48);
  }

  std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr2, Value sgpr3,
                                        ArrayRef<Value> consts) const {
    // The offset literal was lost in extraction; 80 is assumed, consistent
    // with the field layout implied by the other setters.
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
                         80);
  }
  Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr, ArrayRef<Value> consts, size_t dimX,
                    int64_t offset) const {
    ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
    ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
    SmallVector<OpFoldResult> mixedSharedSizes =
        getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
    if (mixedSharedSizes.size() <= dimX)
      return sgpr;

    // Tile sizes are stored outermost-first; dimX counts from the innermost.
    OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
    Value tileDimX;
    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
      tileDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tileDimX = cast<Value>(tileDimXOpFoldResult);
      tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
    }
    return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
  }
  Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr3, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
  }

  Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
  }
  Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr4, ArrayRef<Value> consts) const {
    auto type = cast<VectorType>(op.getIndices().getType());
    ArrayRef<int64_t> shape = type.getShape();
    assert(shape.size() == 1 && "expected shape to be of rank 1.");
    unsigned length = shape.back();
    assert(0 < length && length <= 16 && "expected length to be at most 16.");
    Value value = createI32Constant(rewriter, loc, length);
    return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
  }

  Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr4,
                                  ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
    return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
  }

  Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    // Gather descriptors have no tile dim 2.
    if constexpr (DescriptorOp::isGather())
      return sgpr4;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
  }
  std::pair<Value, Value>
  setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                      size_t dimX, int64_t offset) const {
    ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
    ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
    SmallVector<OpFoldResult> mixedGlobalStrides =
        getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);

    if (mixedGlobalStrides.size() <= (dimX + 1))
      return {sgprY, sgprZ};

    // Strides are stored outermost-first; skip the innermost (unit) stride.
    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX + 1);
    Value tensorDimXStride;
    if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);

    // Strides are 48-bit fields.
    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);

    int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }
  std::pair<Value, Value>
  setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    // Dim 0 and offset 160 are assumed; the literals were lost in extraction.
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               0, 160);
  }

  std::pair<Value, Value>
  setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    // Gather descriptors only carry one stride.
    if constexpr (DescriptorOp::isGather())
      return {sgpr5, sgpr6};
    // Dim 1 and offset 208 are assumed; the literals were lost in extraction.
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
                               1, 208);
  }
  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    Value sgprs[8];
    for (int64_t i = 0; i < 8; ++i) {
      sgprs[i] = consts[0];
    }

    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] =
        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);

    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }
    return dgroup1;
  }
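
// Field offsets visible in the setters above, for reference: workgroup mask
// at bit 0, data size at 16, atomic-barrier enable at 18, iterate enable at
// 19, pad enable at 20, early timeout at 21, pad interval at 22, pad amount
// at 25, barrier address at 32, followed by the tensor dims, tile dims, and
// the 48-bit strides.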
  Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
                      int64_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
      return sgpr0;

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }
    return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
  }

  Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
  }
  Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
                                    Location loc, Value accumulator,
                                    Value value, int64_t shift) const {
    IntegerType i32 = rewriter.getI32Type();
    value = LLVM::TruncOp::create(rewriter, loc, i32, value);
    return setValueAtOffset(rewriter, loc, accumulator, value, shift);
  }

  Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            Value sgpr1, ArrayRef<Value> consts,
                            int64_t offset) const {
    Value ldsAddrIncrement = adaptor.getLdsIncrement();
    return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
  }
  std::pair<Value, Value>
  setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
                         int64_t offset) const {
    Value globalAddrIncrement = adaptor.getGlobalIncrement();
    sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
                                        globalAddrIncrement, offset);
    Value shift = createI64Constant(rewriter, loc, 32);
    globalAddrIncrement =
        LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
    constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
    sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
                                        globalAddrIncrement, offset + 32);
    Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
    sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
    return {sgpr2, sgpr3};
  }
  Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1,
                                        ArrayRef<Value> consts) const {
    Value ldsIncrement = op.getLdsIncrement();
    constexpr int64_t dim = 3;
    constexpr int64_t offset = 32;
    if (!ldsIncrement)
      return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
                           offset);
    return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
                               offset);
  }

  std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
      DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
      Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
    Value globalIncrement = op.getGlobalIncrement();
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 64;
    if (!globalIncrement)
      return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                 consts, dim, offset);
    return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                  consts, offset);
  }
  Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr3, ArrayRef<Value> consts,
                        int32_t offset) const {
    Value iterationCount = adaptor.getIterationCount();
    IntegerType i32 = rewriter.getI32Type();
    // The field is biased: an iteration count of N is encoded as N - 1.
    iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
    iterationCount =
        LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
  }
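
// Illustrative note: an iteration count of 1 is therefore encoded as 0; the
// truncation to i32 happens before the bias so the subtraction stays in the
// field's width.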
  Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr3,
                                  ArrayRef<Value> consts) const {
    Value iterateCount = op.getIterationCount();
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 112;
    if (!iterateCount)
      return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
                         offset);
    return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
  }
  Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
    return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
  }

  Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    constexpr int64_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
                                               sgprs[1], consts);
    std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
        op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
    sgprs[3] =
        setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);

    Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup2 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
    return dgroup2;
  }
  Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts, bool firstHalf) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    Value indices = adaptor.getIndices();
    auto vectorType = cast<VectorType>(indices.getType());
    unsigned length = vectorType.getShape().back();
    Type elementType = vectorType.getElementType();
    unsigned maxLength = elementType == i32 ? 4 : 8;
    int32_t offset = firstHalf ? 0 : maxLength;
    unsigned discountedLength =
        std::max(static_cast<int32_t>(length - offset), 0);
    unsigned targetSize = std::min(maxLength, discountedLength);

    SmallVector<Value> indicesVector;
    for (unsigned i = offset; i < targetSize + offset; ++i) {
      Value idx;
      if (i < consts.size())
        idx = consts[i];
      else
        idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
      indicesVector.push_back(elem);
    }

    SmallVector<Value> indicesI32Vector;
    if (elementType == i32) {
      indicesI32Vector = indicesVector;
    } else {
      for (unsigned i = 0; i < targetSize; ++i) {
        Value index = indicesVector[i];
        indicesI32Vector.push_back(
            LLVM::ZExtOp::create(rewriter, loc, i32, index));
      }
      if ((targetSize % 2) != 0)
        // Pad to an even count so indices can be paired into 32-bit words.
        indicesI32Vector.push_back(consts[0]);
    }

    SmallVector<Value> indicesToInsert;
    if (elementType == i32) {
      indicesToInsert = indicesI32Vector;
    } else {
      unsigned size = indicesI32Vector.size() / 2;
      for (unsigned i = 0; i < size; ++i) {
        Value first = indicesI32Vector[2 * i];
        Value second = indicesI32Vector[2 * i + 1];
        Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
        indicesToInsert.push_back(joined);
      }
    }

    Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
      dgroup =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
    return dgroup;
  }
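
// Worked example (illustrative): sixteen 16-bit gather indices are split
// into two halves of eight; within a half, each pair (first, second) becomes
// one word via setValueAtOffset(first, second, 16), i.e.
// word = first | (second << 16), filling a v4i32 dgroup.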
  Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts) const {
    return getGatherIndices(op, adaptor, rewriter, loc, consts,
                            /*firstHalf=*/true);
  }

  std::pair<Value, Value>
  setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
    constexpr int32_t dim = 3;
    constexpr int32_t offset = 0;
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
                               dim, offset);
  }

  std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    constexpr int32_t dim = 4;
    constexpr int32_t offset = 48;
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
                         offset);
  }

  Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr2, ArrayRef<Value> consts) const {
    constexpr int32_t dim = 4;
    constexpr int32_t offset = 80;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
  }

  Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
    return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
  }

  Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");
    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    constexpr int32_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
        op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);

    Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup3 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
    return dgroup3;
  }

  Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts) const {
    return getGatherIndices(op, adaptor, rewriter, loc, consts,
                            /*firstHalf=*/false);
  }
  LogicalResult
  matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError(
          "make_dma_descriptor is only supported on gfx1250");

    Location loc = op.getLoc();

    SmallVector<Value> consts;
    for (int64_t i = 0; i < 8; ++i)
      consts.push_back(createI32Constant(rewriter, loc, i));

    Value dgroup0 = this->getDGroup0(adaptor);
    Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
    Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
    Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
    SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
    rewriter.replaceOpWithMultiple(op, {results});
    return success();
  }
};
template <typename SourceOp, typename TargetOp>
struct AMDGPUTensorLoadStoreOpLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  // The descriptor operand is type-converted 1:4, so the 1:N adaptor is used.
  using Adaptor = typename ConvertOpToLLVMPattern<SourceOp>::OneToNOpAdaptor;
  Chipset chipset;

  AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(SourceOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError("is only supported on gfx1250");

    // The converted descriptor arrives as the four dgroup values. (The
    // accessor name is assumed; the original line was lost in extraction.)
    ValueRange desc = adaptor.getDesc();
    rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
                                          desc[3]);
    return success();
  }
};
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
    ConversionTarget target(*ctx);
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
void mlir::amdgpu::populateCommonGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(
      typeConverter, [](gpu::AddressSpace space) {
        switch (space) {
        case gpu::AddressSpace::Global:
          return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
        case gpu::AddressSpace::Workgroup:
          return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
        case gpu::AddressSpace::Private:
          return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
        }
        llvm_unreachable("unknown address space enum value");
      });
}

void mlir::amdgpu::populateAMDGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });

  typeConverter.addConversion([&](TDMBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  typeConverter.addConversion(
      [&](TDMDescriptorType type, SmallVectorImpl<Type> &result)
          -> std::optional<LogicalResult> {
        Type i32 = IntegerType::get(type.getContext(), 32);
        Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
        Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
        llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
        return success();
      });

  auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
                              ValueRange inputs,
                              Location loc) -> SmallVector<Value> {
    if (inputs.size() != 1)
      return {};
    if (!isa<TDMDescriptorType>(inputs[0].getType()))
      return {};
    auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
    return cast.getResults();
  };
  typeConverter.addTargetMaterialization(addUnrealizedCast);
}
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
  populateAMDGPUTypeAndAttributeConversions(converter);
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
           SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
           ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
           GatherToLDSOpLowering, TransposeLoadOpLowering,
           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
           AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
           AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
           AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
                                           ROCDL::TensorLoadToLDSOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
                                           ROCDL::TensorStoreFromLDSOp>>(
          converter, chipset);
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}